code cleanup
This commit is contained in:
5
.claude/settings.local.json
Normal file
5
.claude/settings.local.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"permissions": {
|
||||
"defaultMode": "acceptEdits"
|
||||
}
|
||||
}
|
||||
121
CLAUDE.md
Normal file
121
CLAUDE.md
Normal file
@@ -0,0 +1,121 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Development Commands
|
||||
|
||||
### Building and Running
|
||||
```bash
|
||||
# Build the main application
|
||||
go build -o api.exe cmd/api/main.go
|
||||
|
||||
# Build with hot reload (requires air)
|
||||
go install github.com/cosmtrek/air@latest
|
||||
air
|
||||
|
||||
# Run the built binary
|
||||
./api.exe
|
||||
|
||||
# Build migration utility
|
||||
go build -o acc-server-migration.exe cmd/migrate/main.go
|
||||
```
|
||||
|
||||
### Testing
|
||||
```bash
|
||||
# Run all tests
|
||||
go test ./...
|
||||
|
||||
# Run tests with verbose output
|
||||
go test -v ./...
|
||||
|
||||
# Run specific test package
|
||||
go test ./tests/unit/service/
|
||||
go test ./tests/unit/controller/
|
||||
go test ./tests/unit/repository/
|
||||
```
|
||||
|
||||
### Documentation
|
||||
```bash
|
||||
# Generate Swagger documentation (if swag is installed)
|
||||
swag init -g cmd/api/main.go
|
||||
```
|
||||
|
||||
### Setup and Configuration
|
||||
```bash
|
||||
# Generate security configuration (Windows PowerShell)
|
||||
.\scripts\generate-secrets.ps1
|
||||
|
||||
# Deploy (requires configuration)
|
||||
.\scripts\deploy.ps1
|
||||
```
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
This is a Go-based web application for managing Assetto Corsa Competizione (ACC) dedicated servers on Windows. The architecture follows a layered approach:
|
||||
|
||||
### Core Layers
|
||||
- **cmd/**: Application entry points (main.go for API server, migrate/main.go for migrations)
|
||||
- **local/**: Core application code organized by architectural layer
|
||||
- **api/**: HTTP route definitions and API setup
|
||||
- **controller/**: HTTP request handlers (config, membership, server, service_control, steam_2fa, system)
|
||||
- **service/**: Business logic layer (server management, Steam integration, Windows services)
|
||||
- **repository/**: Data access layer with GORM ORM
|
||||
- **model/**: Data models and structures
|
||||
- **middleware/**: HTTP middleware (auth, security, logging)
|
||||
- **utl/**: Utilities organized by function (cache, command execution, JWT, logging, etc.)
|
||||
|
||||
### Key Components
|
||||
|
||||
#### Dependency Injection
|
||||
Uses `go.uber.org/dig` for dependency injection. Main dependencies are set up in `cmd/api/main.go`.
|
||||
|
||||
#### Database
|
||||
- SQLite with GORM ORM
|
||||
- Database migrations in `local/migrations/`
|
||||
- Models support UUID primary keys
|
||||
|
||||
#### Authentication & Security
|
||||
- JWT-based authentication with two token types (regular and "open")
|
||||
- Comprehensive security middleware stack including rate limiting, input sanitization, CORS
|
||||
- Encrypted credential storage for Steam integration
|
||||
|
||||
#### Server Management
|
||||
- Windows service integration via NSSM
|
||||
- Steam integration for server installation/updates via SteamCMD
|
||||
- Interactive command execution for Steam 2FA
|
||||
- Firewall management
|
||||
- Configuration file generation and management
|
||||
|
||||
#### Logging
|
||||
- Custom logging system with multiple levels (debug, info, warn, error)
|
||||
- Request logging middleware
|
||||
- Structured logging with categories
|
||||
|
||||
### Testing Structure
|
||||
- Unit tests in `tests/unit/` organized by layer (controller, service, repository)
|
||||
- Test helpers and mocks in `tests/` directory
|
||||
- Uses standard Go testing with mocks for external dependencies
|
||||
|
||||
### External Dependencies
|
||||
- **Fiber v2**: Web framework
|
||||
- **GORM**: ORM for database operations
|
||||
- **SteamCMD**: External tool for Steam server management (configured via STEAMCMD_PATH env var)
|
||||
- **NSSM**: Windows service management (configured via NSSM_PATH env var)
|
||||
|
||||
### Configuration
|
||||
- Environment variables for external tool paths and configuration
|
||||
- JWT secrets generated via setup scripts
|
||||
- CORS configuration with configurable allowed origins
|
||||
- Default port 3000 (configurable via PORT env var)
|
||||
|
||||
## Important Notes
|
||||
|
||||
### Windows-Specific Features
|
||||
This application is designed specifically for Windows and includes:
|
||||
- Windows service management integration
|
||||
- PowerShell script execution
|
||||
- Windows-specific path handling
|
||||
- Firewall rule management
|
||||
|
||||
### Steam Integration
|
||||
The Steam 2FA implementation (`local/controller/steam_2fa.go`, `local/model/steam_2fa.go`) provides interactive Steam authentication for automated server management.
|
||||
1986
frontend.md
Normal file
1986
frontend.md
Normal file
File diff suppressed because it is too large
Load Diff
@@ -29,7 +29,6 @@ func NewServerController(ss *service.ServerService, routeGroups *common.RouteGro
|
||||
serverRoutes.Get("/", auth.HasPermission(model.ServerView), ac.GetAll)
|
||||
serverRoutes.Get("/:id", auth.HasPermission(model.ServerView), ac.GetById)
|
||||
serverRoutes.Post("/", auth.HasPermission(model.ServerCreate), ac.CreateServer)
|
||||
serverRoutes.Put("/:id", auth.HasPermission(model.ServerUpdate), ac.UpdateServer)
|
||||
serverRoutes.Delete("/:id", auth.HasPermission(model.ServerDelete), ac.DeleteServer)
|
||||
|
||||
apiServerRoutes := routeGroups.Api.Group("/server")
|
||||
@@ -152,41 +151,6 @@ func (ac *ServerController) CreateServer(c *fiber.Ctx) error {
|
||||
return c.JSON(server)
|
||||
}
|
||||
|
||||
// UpdateServer updates an existing server
|
||||
// @Summary Update an ACC server
|
||||
// @Description Update configuration for an existing ACC server
|
||||
// @Tags Server
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "Server ID (UUID format)"
|
||||
// @Param server body model.Server true "Updated server configuration"
|
||||
// @Success 200 {object} object "Updated server details"
|
||||
// @Failure 400 {object} error_handler.ErrorResponse "Invalid server data or ID"
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
|
||||
// @Failure 403 {object} error_handler.ErrorResponse "Insufficient permissions"
|
||||
// @Failure 404 {object} error_handler.ErrorResponse "Server not found"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /server/{id} [put]
|
||||
func (ac *ServerController) UpdateServer(c *fiber.Ctx) error {
|
||||
serverIDStr := c.Params("id")
|
||||
serverID, err := uuid.Parse(serverIDStr)
|
||||
if err != nil {
|
||||
return ac.errorHandler.HandleUUIDError(c, "server ID")
|
||||
}
|
||||
|
||||
server := new(model.Server)
|
||||
if err := c.BodyParser(server); err != nil {
|
||||
return ac.errorHandler.HandleParsingError(c, err)
|
||||
}
|
||||
server.ID = serverID
|
||||
|
||||
if err := ac.service.UpdateServer(c, server); err != nil {
|
||||
return ac.errorHandler.HandleServiceError(c, err)
|
||||
}
|
||||
return c.JSON(server)
|
||||
}
|
||||
|
||||
// DeleteServer deletes an existing server
|
||||
// @Summary Delete an ACC server
|
||||
// @Description Delete an existing ACC server
|
||||
|
||||
@@ -17,7 +17,6 @@ type WebSocketController struct {
|
||||
jwtHandler *jwt.OpenJWTHandler
|
||||
}
|
||||
|
||||
// NewWebSocketController initializes WebSocketController
|
||||
func NewWebSocketController(
|
||||
wsService *service.WebSocketService,
|
||||
jwtHandler *jwt.OpenJWTHandler,
|
||||
@@ -29,7 +28,6 @@ func NewWebSocketController(
|
||||
jwtHandler: jwtHandler,
|
||||
}
|
||||
|
||||
// WebSocket routes
|
||||
wsRoutes := routeGroups.WebSocket
|
||||
wsRoutes.Use("/", wsc.upgradeWebSocket)
|
||||
wsRoutes.Get("/", websocket.New(wsc.handleWebSocket))
|
||||
@@ -37,11 +35,8 @@ func NewWebSocketController(
|
||||
return wsc
|
||||
}
|
||||
|
||||
// upgradeWebSocket middleware to upgrade HTTP to WebSocket and validate authentication
|
||||
func (wsc *WebSocketController) upgradeWebSocket(c *fiber.Ctx) error {
|
||||
// Check if it's a WebSocket upgrade request
|
||||
if websocket.IsWebSocketUpgrade(c) {
|
||||
// Validate JWT token from query parameter or header
|
||||
token := c.Query("token")
|
||||
if token == "" {
|
||||
token = c.Get("Authorization")
|
||||
@@ -54,21 +49,18 @@ func (wsc *WebSocketController) upgradeWebSocket(c *fiber.Ctx) error {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "Missing authentication token")
|
||||
}
|
||||
|
||||
// Validate the token
|
||||
claims, err := wsc.jwtHandler.ValidateToken(token)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "Invalid authentication token")
|
||||
}
|
||||
|
||||
// Parse UserID string to UUID
|
||||
userID, err := uuid.Parse(claims.UserID)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "Invalid user ID in token")
|
||||
}
|
||||
|
||||
// Store user info in context for use in WebSocket handler
|
||||
c.Locals("userID", userID)
|
||||
c.Locals("username", claims.UserID) // Use UserID as username for now
|
||||
c.Locals("username", claims.UserID)
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
@@ -76,12 +68,9 @@ func (wsc *WebSocketController) upgradeWebSocket(c *fiber.Ctx) error {
|
||||
return fiber.NewError(fiber.StatusUpgradeRequired, "WebSocket upgrade required")
|
||||
}
|
||||
|
||||
// handleWebSocket handles WebSocket connections
|
||||
func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) {
|
||||
// Generate a unique connection ID
|
||||
connID := uuid.New().String()
|
||||
|
||||
// Get user info from locals (set by middleware)
|
||||
userID, ok := c.Locals("userID").(uuid.UUID)
|
||||
if !ok {
|
||||
logging.Error("Failed to get user ID from WebSocket connection")
|
||||
@@ -92,16 +81,13 @@ func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) {
|
||||
username, _ := c.Locals("username").(string)
|
||||
logging.Info("WebSocket connection established for user: %s (ID: %s)", username, userID.String())
|
||||
|
||||
// Add the connection to the service
|
||||
wsc.webSocketService.AddConnection(connID, c, &userID)
|
||||
|
||||
// Handle connection cleanup
|
||||
defer func() {
|
||||
wsc.webSocketService.RemoveConnection(connID)
|
||||
logging.Info("WebSocket connection closed for user: %s", username)
|
||||
}()
|
||||
|
||||
// Handle incoming messages from the client
|
||||
for {
|
||||
messageType, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
@@ -111,14 +97,12 @@ func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) {
|
||||
break
|
||||
}
|
||||
|
||||
// Handle different message types
|
||||
switch messageType {
|
||||
case websocket.TextMessage:
|
||||
wsc.handleTextMessage(connID, userID, message)
|
||||
case websocket.BinaryMessage:
|
||||
logging.Debug("Received binary message from user %s (not supported)", username)
|
||||
case websocket.PingMessage:
|
||||
// Respond with pong
|
||||
if err := c.WriteMessage(websocket.PongMessage, nil); err != nil {
|
||||
logging.Error("Failed to send pong to user %s: %v", username, err)
|
||||
break
|
||||
@@ -127,21 +111,11 @@ func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleTextMessage processes text messages from the client
|
||||
func (wsc *WebSocketController) handleTextMessage(connID string, userID uuid.UUID, message []byte) {
|
||||
logging.Debug("Received WebSocket message from user %s: %s", userID.String(), string(message))
|
||||
|
||||
// Parse the message to handle different types of client requests
|
||||
// For now, we'll just log it. In the future, you might want to handle:
|
||||
// - Subscription to specific server creation processes
|
||||
// - Client heartbeat/keepalive
|
||||
// - Request for status updates
|
||||
|
||||
// Example: If the message contains a server ID, associate this connection with that server
|
||||
// This is a simple implementation - you might want to use proper JSON parsing
|
||||
messageStr := string(message)
|
||||
if len(messageStr) > 10 && messageStr[:9] == "server_id" {
|
||||
// Extract server ID from message like "server_id:uuid"
|
||||
if serverIDStr := messageStr[10:]; len(serverIDStr) > 0 {
|
||||
if serverID, err := uuid.Parse(serverIDStr); err == nil {
|
||||
wsc.webSocketService.SetServerID(connID, serverID)
|
||||
@@ -151,18 +125,14 @@ func (wsc *WebSocketController) handleTextMessage(connID string, userID uuid.UUI
|
||||
}
|
||||
}
|
||||
|
||||
// GetWebSocketUpgrade returns the WebSocket upgrade handler for use in other controllers
|
||||
func (wsc *WebSocketController) GetWebSocketUpgrade() fiber.Handler {
|
||||
return wsc.upgradeWebSocket
|
||||
}
|
||||
|
||||
// GetWebSocketHandler returns the WebSocket connection handler for use in other controllers
|
||||
func (wsc *WebSocketController) GetWebSocketHandler() func(*websocket.Conn) {
|
||||
return wsc.handleWebSocket
|
||||
}
|
||||
|
||||
// BroadcastServerCreationProgress is a helper method for other services to broadcast progress
|
||||
func (wsc *WebSocketController) BroadcastServerCreationProgress(serverID uuid.UUID, step string, status string, message string) {
|
||||
// This can be used by the ServerService during server creation
|
||||
logging.Info("Broadcasting server creation progress: %s - %s: %s", serverID.String(), step, status)
|
||||
}
|
||||
|
||||
@@ -10,12 +10,10 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// AccessKeyMiddleware provides authentication and permission middleware.
|
||||
type AccessKeyMiddleware struct {
|
||||
userInfo CachedUserInfo
|
||||
}
|
||||
|
||||
// NewAccessKeyMiddleware creates a new AccessKeyMiddleware.
|
||||
func NewAccessKeyMiddleware() *AccessKeyMiddleware {
|
||||
auth := &AccessKeyMiddleware{
|
||||
userInfo: CachedUserInfo{UserID: uuid.New().String(), Username: "access_key", RoleName: "Admin", Permissions: map[string]bool{
|
||||
@@ -25,9 +23,7 @@ func NewAccessKeyMiddleware() *AccessKeyMiddleware {
|
||||
return auth
|
||||
}
|
||||
|
||||
// Authenticate is a middleware for JWT authentication with enhanced security.
|
||||
func (m *AccessKeyMiddleware) Authenticate(ctx *fiber.Ctx) error {
|
||||
// Log authentication attempt
|
||||
ip := ctx.IP()
|
||||
userAgent := ctx.Get("User-Agent")
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// CachedUserInfo holds cached user authentication and permission data
|
||||
type CachedUserInfo struct {
|
||||
UserID string
|
||||
Username string
|
||||
@@ -25,7 +24,6 @@ type CachedUserInfo struct {
|
||||
CachedAt time.Time
|
||||
}
|
||||
|
||||
// AuthMiddleware provides authentication and permission middleware.
|
||||
type AuthMiddleware struct {
|
||||
membershipService *service.MembershipService
|
||||
cache *cache.InMemoryCache
|
||||
@@ -34,7 +32,6 @@ type AuthMiddleware struct {
|
||||
openJWTHandler *jwt.OpenJWTHandler
|
||||
}
|
||||
|
||||
// NewAuthMiddleware creates a new AuthMiddleware.
|
||||
func NewAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache, jwtHandler *jwt.JWTHandler, openJWTHandler *jwt.OpenJWTHandler) *AuthMiddleware {
|
||||
auth := &AuthMiddleware{
|
||||
membershipService: ms,
|
||||
@@ -44,24 +41,20 @@ func NewAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache
|
||||
openJWTHandler: openJWTHandler,
|
||||
}
|
||||
|
||||
// Set up bidirectional relationship for cache invalidation
|
||||
ms.SetCacheInvalidator(auth)
|
||||
|
||||
return auth
|
||||
}
|
||||
|
||||
// Authenticate is a middleware for JWT authentication with enhanced security.
|
||||
func (m *AuthMiddleware) AuthenticateOpen(ctx *fiber.Ctx) error {
|
||||
return m.AuthenticateWithHandler(m.openJWTHandler.JWTHandler, true, ctx)
|
||||
}
|
||||
|
||||
// Authenticate is a middleware for JWT authentication with enhanced security.
|
||||
func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
|
||||
return m.AuthenticateWithHandler(m.jwtHandler, false, ctx)
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isOpenToken bool, ctx *fiber.Ctx) error {
|
||||
// Log authentication attempt
|
||||
ip := ctx.IP()
|
||||
userAgent := ctx.Get("User-Agent")
|
||||
|
||||
@@ -89,7 +82,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO
|
||||
})
|
||||
}
|
||||
|
||||
// Validate token length to prevent potential attacks
|
||||
token := parts[1]
|
||||
if len(token) < 10 || len(token) > 2048 {
|
||||
logging.Error("Authentication failed: invalid token length from IP %s", ip)
|
||||
@@ -113,7 +105,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO
|
||||
})
|
||||
}
|
||||
|
||||
// Additional security: validate user ID format
|
||||
if claims.UserID == "" || len(claims.UserID) < 10 {
|
||||
logging.Error("Authentication failed: invalid user ID in token from IP %s", ip)
|
||||
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
@@ -127,7 +118,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO
|
||||
ctx.Locals("userInfo", userInfo)
|
||||
ctx.Locals("authTime", time.Now())
|
||||
} else {
|
||||
// Preload and cache user info to avoid database queries on permission checks
|
||||
userInfo, err := m.getCachedUserInfo(ctx.UserContext(), claims.UserID)
|
||||
if err != nil {
|
||||
logging.Error("Authentication failed: unable to load user info for %s from IP %s: %v", claims.UserID, ip, err)
|
||||
@@ -145,7 +135,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
// HasPermission is a middleware for checking user permissions with enhanced logging.
|
||||
func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
userID, ok := ctx.Locals("userID").(string)
|
||||
@@ -160,7 +149,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
// Validate permission parameter
|
||||
if requiredPermission == "" {
|
||||
logging.Error("Permission check failed: empty permission requirement")
|
||||
return ctx.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
@@ -168,7 +156,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
|
||||
})
|
||||
}
|
||||
|
||||
// Use cached user info from authentication step - no database queries needed
|
||||
userInfo, ok := ctx.Locals("userInfo").(*CachedUserInfo)
|
||||
if !ok {
|
||||
logging.Error("Permission check failed: no cached user info in context from IP %s", ctx.IP())
|
||||
@@ -177,7 +164,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
|
||||
})
|
||||
}
|
||||
|
||||
// Check if user has permission using cached data
|
||||
has := m.hasPermissionFromCache(userInfo, requiredPermission)
|
||||
|
||||
if !has {
|
||||
@@ -192,16 +178,13 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
|
||||
}
|
||||
}
|
||||
|
||||
// AuthRateLimit applies rate limiting specifically for authentication endpoints
|
||||
func (m *AuthMiddleware) AuthRateLimit() fiber.Handler {
|
||||
return m.securityMW.AuthRateLimit()
|
||||
}
|
||||
|
||||
// RequireHTTPS redirects HTTP requests to HTTPS in production
|
||||
func (m *AuthMiddleware) RequireHTTPS() fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
if ctx.Protocol() != "https" && ctx.Get("X-Forwarded-Proto") != "https" {
|
||||
// Allow HTTP in development/testing
|
||||
if ctx.Hostname() != "localhost" && ctx.Hostname() != "127.0.0.1" {
|
||||
httpsURL := "https://" + ctx.Hostname() + ctx.OriginalURL()
|
||||
return ctx.Redirect(httpsURL, fiber.StatusMovedPermanently)
|
||||
@@ -211,11 +194,9 @@ func (m *AuthMiddleware) RequireHTTPS() fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// getCachedUserInfo retrieves and caches complete user information including permissions
|
||||
func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) (*CachedUserInfo, error) {
|
||||
cacheKey := fmt.Sprintf("userinfo:%s", userID)
|
||||
|
||||
// Try cache first
|
||||
if cached, found := m.cache.Get(cacheKey); found {
|
||||
if userInfo, ok := cached.(*CachedUserInfo); ok {
|
||||
logging.DebugWithContext("AUTH_CACHE", "User info for %s found in cache", userID)
|
||||
@@ -223,13 +204,11 @@ func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) (
|
||||
}
|
||||
}
|
||||
|
||||
// Cache miss - load from database
|
||||
user, err := m.membershipService.GetUserWithPermissions(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build permission map for fast lookups
|
||||
permissions := make(map[string]bool)
|
||||
for _, p := range user.Role.Permissions {
|
||||
permissions[p.Name] = true
|
||||
@@ -243,34 +222,26 @@ func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) (
|
||||
CachedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Cache for 15 minutes
|
||||
m.cache.Set(cacheKey, userInfo, 15*time.Minute)
|
||||
logging.DebugWithContext("AUTH_CACHE", "User info for %s cached with %d permissions", userID, len(permissions))
|
||||
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
// hasPermissionFromCache checks permissions using cached user info (no database queries)
|
||||
func (m *AuthMiddleware) hasPermissionFromCache(userInfo *CachedUserInfo, permission string) bool {
|
||||
// Super Admin and Admin have all permissions
|
||||
if userInfo.RoleName == "Super Admin" || userInfo.RoleName == "Admin" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check specific permission in cached map
|
||||
return userInfo.Permissions[permission]
|
||||
}
|
||||
|
||||
// InvalidateUserPermissions removes cached user info for a user
|
||||
func (m *AuthMiddleware) InvalidateUserPermissions(userID string) {
|
||||
cacheKey := fmt.Sprintf("userinfo:%s", userID)
|
||||
m.cache.Delete(cacheKey)
|
||||
logging.InfoWithContext("AUTH_CACHE", "User info cache invalidated for user %s", userID)
|
||||
}
|
||||
|
||||
// InvalidateAllUserPermissions clears all cached user info (useful for role/permission changes)
|
||||
func (m *AuthMiddleware) InvalidateAllUserPermissions() {
|
||||
// This would need to be implemented based on your cache interface
|
||||
// For now, just log that invalidation was requested
|
||||
logging.InfoWithContext("AUTH_CACHE", "All user info caches invalidation requested")
|
||||
}
|
||||
|
||||
@@ -7,25 +7,20 @@ import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// RequestLoggingMiddleware logs HTTP requests and responses
|
||||
type RequestLoggingMiddleware struct {
|
||||
infoLogger *logging.InfoLogger
|
||||
}
|
||||
|
||||
// NewRequestLoggingMiddleware creates a new request logging middleware
|
||||
func NewRequestLoggingMiddleware() *RequestLoggingMiddleware {
|
||||
return &RequestLoggingMiddleware{
|
||||
infoLogger: logging.GetInfoLogger(),
|
||||
}
|
||||
}
|
||||
|
||||
// Handler returns the middleware handler function
|
||||
func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Record start time
|
||||
start := time.Now()
|
||||
|
||||
// Log incoming request
|
||||
userAgent := c.Get("User-Agent")
|
||||
if userAgent == "" {
|
||||
userAgent = "Unknown"
|
||||
@@ -33,17 +28,13 @@ func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler {
|
||||
|
||||
rlm.infoLogger.LogRequest(c.Method(), c.OriginalURL(), userAgent)
|
||||
|
||||
// Continue to next handler
|
||||
err := c.Next()
|
||||
|
||||
// Calculate duration
|
||||
duration := time.Since(start)
|
||||
|
||||
// Log response
|
||||
statusCode := c.Response().StatusCode()
|
||||
rlm.infoLogger.LogResponse(c.Method(), c.OriginalURL(), statusCode, duration.String())
|
||||
|
||||
// Log error if present
|
||||
if err != nil {
|
||||
logging.ErrorWithContext("REQUEST_MIDDLEWARE", "Request failed: %v", err)
|
||||
}
|
||||
@@ -52,10 +43,8 @@ func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// Global request logging middleware instance
|
||||
var globalRequestLoggingMiddleware *RequestLoggingMiddleware
|
||||
|
||||
// GetRequestLoggingMiddleware returns the global request logging middleware
|
||||
func GetRequestLoggingMiddleware() *RequestLoggingMiddleware {
|
||||
if globalRequestLoggingMiddleware == nil {
|
||||
globalRequestLoggingMiddleware = NewRequestLoggingMiddleware()
|
||||
@@ -63,7 +52,6 @@ func GetRequestLoggingMiddleware() *RequestLoggingMiddleware {
|
||||
return globalRequestLoggingMiddleware
|
||||
}
|
||||
|
||||
// Handler returns the global request logging middleware handler
|
||||
func Handler() fiber.Handler {
|
||||
return GetRequestLoggingMiddleware().Handler()
|
||||
}
|
||||
|
||||
@@ -11,19 +11,16 @@ import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// RateLimiter stores rate limiting information
|
||||
type RateLimiter struct {
|
||||
requests map[string][]time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter
|
||||
func NewRateLimiter() *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
requests: make(map[string][]time.Time),
|
||||
}
|
||||
|
||||
// Use graceful shutdown for cleanup goroutine
|
||||
shutdownManager := graceful.GetManager()
|
||||
shutdownManager.RunGoroutine(func(ctx context.Context) {
|
||||
rl.cleanupWithContext(ctx)
|
||||
@@ -32,7 +29,6 @@ func NewRateLimiter() *RateLimiter {
|
||||
return rl
|
||||
}
|
||||
|
||||
// cleanup removes old entries from the rate limiter
|
||||
func (rl *RateLimiter) cleanupWithContext(ctx context.Context) {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
@@ -62,44 +58,29 @@ func (rl *RateLimiter) cleanupWithContext(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// SecurityMiddleware provides comprehensive security middleware
|
||||
type SecurityMiddleware struct {
|
||||
rateLimiter *RateLimiter
|
||||
}
|
||||
|
||||
// NewSecurityMiddleware creates a new security middleware
|
||||
func NewSecurityMiddleware() *SecurityMiddleware {
|
||||
return &SecurityMiddleware{
|
||||
rateLimiter: NewRateLimiter(),
|
||||
}
|
||||
}
|
||||
|
||||
// SecurityHeaders adds security headers to responses
|
||||
func (sm *SecurityMiddleware) SecurityHeaders() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Prevent MIME type sniffing
|
||||
c.Set("X-Content-Type-Options", "nosniff")
|
||||
|
||||
// Prevent clickjacking
|
||||
c.Set("X-Frame-Options", "DENY")
|
||||
|
||||
// Enable XSS protection
|
||||
c.Set("X-XSS-Protection", "1; mode=block")
|
||||
|
||||
// Prevent referrer leakage
|
||||
c.Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
|
||||
// Content Security Policy
|
||||
c.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none'")
|
||||
|
||||
// Permissions Policy
|
||||
c.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), interest-cohort=()")
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimit implements rate limiting for API endpoints
|
||||
func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
ip := c.IP()
|
||||
@@ -111,7 +92,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
|
||||
now := time.Now()
|
||||
requests := sm.rateLimiter.requests[key]
|
||||
|
||||
// Remove requests older than duration
|
||||
filtered := make([]time.Time, 0, len(requests))
|
||||
for _, t := range requests {
|
||||
if now.Sub(t) < duration {
|
||||
@@ -119,7 +99,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if limit is exceeded
|
||||
if len(filtered) >= maxRequests {
|
||||
return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
|
||||
"error": "Rate limit exceeded",
|
||||
@@ -127,7 +106,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
|
||||
})
|
||||
}
|
||||
|
||||
// Add current request
|
||||
filtered = append(filtered, now)
|
||||
sm.rateLimiter.requests[key] = filtered
|
||||
|
||||
@@ -135,7 +113,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
|
||||
}
|
||||
}
|
||||
|
||||
// AuthRateLimit implements stricter rate limiting for authentication endpoints
|
||||
func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
ip := c.IP()
|
||||
@@ -148,7 +125,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
|
||||
now := time.Now()
|
||||
requests := sm.rateLimiter.requests[key]
|
||||
|
||||
// Remove requests older than 15 minutes
|
||||
filtered := make([]time.Time, 0, len(requests))
|
||||
for _, t := range requests {
|
||||
if now.Sub(t) < 15*time.Minute {
|
||||
@@ -156,7 +132,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// Check if limit is exceeded (5 requests per 15 minutes for auth)
|
||||
if len(filtered) >= 5 {
|
||||
return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
|
||||
"error": "Too many authentication attempts",
|
||||
@@ -164,7 +139,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
// Add current request
|
||||
filtered = append(filtered, now)
|
||||
sm.rateLimiter.requests[key] = filtered
|
||||
|
||||
@@ -172,20 +146,16 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// InputSanitization sanitizes user input to prevent XSS and injection attacks
|
||||
func (sm *SecurityMiddleware) InputSanitization() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Sanitize query parameters
|
||||
c.Request().URI().QueryArgs().VisitAll(func(key, value []byte) {
|
||||
sanitized := sanitizeInput(string(value))
|
||||
c.Request().URI().QueryArgs().Set(string(key), sanitized)
|
||||
})
|
||||
|
||||
// Store original body for processing
|
||||
if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" {
|
||||
body := c.Body()
|
||||
if len(body) > 0 {
|
||||
// Basic sanitization - remove potentially dangerous patterns
|
||||
sanitized := sanitizeInput(string(body))
|
||||
c.Request().SetBodyString(sanitized)
|
||||
}
|
||||
@@ -195,7 +165,6 @@ func (sm *SecurityMiddleware) InputSanitization() fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// sanitizeInput removes potentially dangerous patterns from input
|
||||
func sanitizeInput(input string) string {
|
||||
dangerous := []string{
|
||||
"<script",
|
||||
@@ -236,7 +205,7 @@ func sanitizeInput(input string) string {
|
||||
|
||||
result := input
|
||||
lowerInput := strings.ToLower(input)
|
||||
|
||||
|
||||
for _, pattern := range dangerous {
|
||||
if strings.Contains(lowerInput, pattern) {
|
||||
return ""
|
||||
@@ -254,7 +223,6 @@ func sanitizeInput(input string) string {
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateContentType ensures only expected content types are accepted
|
||||
func (sm *SecurityMiddleware) ValidateContentType(allowedTypes ...string) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" {
|
||||
@@ -265,7 +233,6 @@ func (sm *SecurityMiddleware) ValidateContentType(allowedTypes ...string) fiber.
|
||||
})
|
||||
}
|
||||
|
||||
// Check if content type is allowed
|
||||
allowed := false
|
||||
for _, allowedType := range allowedTypes {
|
||||
if strings.Contains(contentType, allowedType) {
|
||||
@@ -285,7 +252,6 @@ func (sm *SecurityMiddleware) ValidateContentType(allowedTypes ...string) fiber.
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateUserAgent blocks requests with suspicious or missing user agents
|
||||
func (sm *SecurityMiddleware) ValidateUserAgent() fiber.Handler {
|
||||
suspiciousAgents := []string{
|
||||
"sqlmap",
|
||||
@@ -296,21 +262,19 @@ func (sm *SecurityMiddleware) ValidateUserAgent() fiber.Handler {
|
||||
"dirb",
|
||||
"dirbuster",
|
||||
"wpscan",
|
||||
"curl/7.0", // Very old curl versions
|
||||
"wget/1.0", // Very old wget versions
|
||||
"curl/7.0",
|
||||
"wget/1.0",
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
userAgent := strings.ToLower(c.Get("User-Agent"))
|
||||
|
||||
// Block empty user agents
|
||||
if userAgent == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "User-Agent header is required",
|
||||
})
|
||||
}
|
||||
|
||||
// Block suspicious user agents
|
||||
for _, suspicious := range suspiciousAgents {
|
||||
if strings.Contains(userAgent, suspicious) {
|
||||
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
|
||||
@@ -323,7 +287,6 @@ func (sm *SecurityMiddleware) ValidateUserAgent() fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// RequestSizeLimit limits the size of incoming requests
|
||||
func (sm *SecurityMiddleware) RequestSizeLimit(maxSize int) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" {
|
||||
@@ -340,19 +303,15 @@ func (sm *SecurityMiddleware) RequestSizeLimit(maxSize int) fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// LogSecurityEvents logs security-related events
|
||||
func (sm *SecurityMiddleware) LogSecurityEvents() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
start := time.Now()
|
||||
|
||||
// Process request
|
||||
err := c.Next()
|
||||
|
||||
// Log suspicious activity
|
||||
status := c.Response().StatusCode()
|
||||
if status == 401 || status == 403 || status == 429 {
|
||||
duration := time.Since(start)
|
||||
// In a real implementation, you would send this to your logging system
|
||||
fmt.Printf("[SECURITY] %s %s %s %d %v %s\n",
|
||||
time.Now().Format(time.RFC3339),
|
||||
c.IP(),
|
||||
@@ -367,7 +326,6 @@ func (sm *SecurityMiddleware) LogSecurityEvents() fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// TimeoutMiddleware adds request timeout
|
||||
func (sm *SecurityMiddleware) TimeoutMiddleware(timeout time.Duration) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
ctx, cancel := context.WithTimeout(c.UserContext(), timeout)
|
||||
|
||||
@@ -9,21 +9,17 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Migration001UpgradePasswordSecurity migrates existing user passwords from encrypted to hashed format
|
||||
type Migration001UpgradePasswordSecurity struct {
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
// NewMigration001UpgradePasswordSecurity creates a new password security migration
|
||||
func NewMigration001UpgradePasswordSecurity(db *gorm.DB) *Migration001UpgradePasswordSecurity {
|
||||
return &Migration001UpgradePasswordSecurity{DB: db}
|
||||
}
|
||||
|
||||
// Up executes the migration
|
||||
func (m *Migration001UpgradePasswordSecurity) Up() error {
|
||||
logging.Info("Starting password security upgrade migration...")
|
||||
|
||||
// Check if migration has already been applied
|
||||
var migrationRecord MigrationRecord
|
||||
err := m.DB.Where("migration_name = ?", "001_upgrade_password_security").First(&migrationRecord).Error
|
||||
if err == nil {
|
||||
@@ -31,12 +27,10 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create migration tracking table if it doesn't exist
|
||||
if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil {
|
||||
return fmt.Errorf("failed to create migration tracking table: %v", err)
|
||||
}
|
||||
|
||||
// Start transaction
|
||||
tx := m.DB.Begin()
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("failed to start transaction: %v", tx.Error)
|
||||
@@ -47,16 +41,13 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
|
||||
}
|
||||
}()
|
||||
|
||||
// Add a backup column for old passwords (temporary)
|
||||
if err := tx.Exec("ALTER TABLE users ADD COLUMN password_backup TEXT").Error; err != nil {
|
||||
// Column might already exist, ignore if it's a duplicate column error
|
||||
if !isDuplicateColumnError(err) {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to add backup column: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get all users with encrypted passwords
|
||||
var users []UserForMigration
|
||||
if err := tx.Find(&users).Error; err != nil {
|
||||
tx.Rollback()
|
||||
@@ -72,19 +63,15 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
|
||||
if err := m.migrateUserPassword(tx, &user); err != nil {
|
||||
logging.Error("Failed to migrate user %s (ID: %s): %v", user.Username, user.ID, err)
|
||||
failedCount++
|
||||
// Continue with other users rather than failing completely
|
||||
continue
|
||||
}
|
||||
migratedCount++
|
||||
}
|
||||
|
||||
// Remove backup column after successful migration
|
||||
if err := tx.Exec("ALTER TABLE users DROP COLUMN password_backup").Error; err != nil {
|
||||
logging.Error("Failed to remove backup column (non-critical): %v", err)
|
||||
// Don't fail the migration for this
|
||||
}
|
||||
|
||||
// Record successful migration
|
||||
migrationRecord = MigrationRecord{
|
||||
MigrationName: "001_upgrade_password_security",
|
||||
AppliedAt: "datetime('now')",
|
||||
@@ -97,7 +84,6 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
|
||||
return fmt.Errorf("failed to record migration: %v", err)
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return fmt.Errorf("failed to commit migration: %v", err)
|
||||
}
|
||||
@@ -111,32 +97,24 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateUserPassword migrates a single user's password
|
||||
func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, user *UserForMigration) error {
|
||||
// Skip if password is already hashed (bcrypt hashes start with $2a$, $2b$, or $2y$)
|
||||
if isAlreadyHashed(user.Password) {
|
||||
logging.Debug("User %s already has hashed password, skipping", user.Username)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Backup original password
|
||||
if err := tx.Model(user).Update("password_backup", user.Password).Error; err != nil {
|
||||
return fmt.Errorf("failed to backup password: %v", err)
|
||||
}
|
||||
|
||||
// Try to decrypt the old password
|
||||
var plainPassword string
|
||||
|
||||
// First, try to decrypt using the old encryption method
|
||||
decrypted, err := decryptOldPassword(user.Password)
|
||||
if err != nil {
|
||||
// If decryption fails, the password might already be plain text or corrupted
|
||||
logging.Error("Failed to decrypt password for user %s, treating as plain text: %v", user.Username, err)
|
||||
|
||||
// Use original password as-is (might be plain text from development)
|
||||
plainPassword = user.Password
|
||||
|
||||
// Validate it's not obviously encrypted data
|
||||
if len(plainPassword) > 100 || containsBinaryData(plainPassword) {
|
||||
return fmt.Errorf("password appears to be corrupted encrypted data")
|
||||
}
|
||||
@@ -144,7 +122,6 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u
|
||||
plainPassword = decrypted
|
||||
}
|
||||
|
||||
// Validate plain password
|
||||
if plainPassword == "" {
|
||||
return errors.New("decrypted password is empty")
|
||||
}
|
||||
@@ -153,13 +130,11 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u
|
||||
return errors.New("password too short after decryption")
|
||||
}
|
||||
|
||||
// Hash the plain password using bcrypt
|
||||
hashedPassword, err := password.HashPassword(plainPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %v", err)
|
||||
}
|
||||
|
||||
// Update with hashed password
|
||||
if err := tx.Model(user).Update("password", hashedPassword).Error; err != nil {
|
||||
return fmt.Errorf("failed to update password: %v", err)
|
||||
}
|
||||
@@ -168,19 +143,16 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserForMigration represents a user record for migration purposes
|
||||
type UserForMigration struct {
|
||||
ID string `gorm:"column:id"`
|
||||
Username string `gorm:"column:username"`
|
||||
Password string `gorm:"column:password"`
|
||||
}
|
||||
|
||||
// TableName specifies the table name for GORM
|
||||
func (UserForMigration) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
// MigrationRecord tracks applied migrations
|
||||
type MigrationRecord struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
MigrationName string `gorm:"unique;not null"`
|
||||
@@ -189,49 +161,38 @@ type MigrationRecord struct {
|
||||
Notes string
|
||||
}
|
||||
|
||||
// TableName specifies the table name for GORM
|
||||
func (MigrationRecord) TableName() string {
|
||||
return "migration_records"
|
||||
}
|
||||
|
||||
// isAlreadyHashed checks if a password is already bcrypt hashed
|
||||
func isAlreadyHashed(password string) bool {
|
||||
return len(password) >= 60 && (password[:4] == "$2a$" || password[:4] == "$2b$" || password[:4] == "$2y$")
|
||||
}
|
||||
|
||||
// containsBinaryData checks if a string contains binary data
|
||||
func containsBinaryData(s string) bool {
|
||||
for _, b := range []byte(s) {
|
||||
if b < 32 && b != 9 && b != 10 && b != 13 { // Allow tab, newline, carriage return
|
||||
if b < 32 && b != 9 && b != 10 && b != 13 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isDuplicateColumnError checks if an error is due to duplicate column
|
||||
func isDuplicateColumnError(err error) bool {
|
||||
errStr := err.Error()
|
||||
return fmt.Sprintf("%v", errStr) == "duplicate column name: password_backup" ||
|
||||
fmt.Sprintf("%v", errStr) == "SQLITE_ERROR: duplicate column name: password_backup"
|
||||
}
|
||||
|
||||
// decryptOldPassword attempts to decrypt using the old encryption method
|
||||
// This is a simplified version of the old DecryptPassword function
|
||||
func decryptOldPassword(encryptedPassword string) (string, error) {
|
||||
// This would use the old decryption logic
|
||||
// For now, we'll return an error to force treating as plain text
|
||||
// In a real scenario, you'd implement the old decryption here
|
||||
return "", errors.New("old decryption not implemented - treating as plain text")
|
||||
}
|
||||
|
||||
// Down reverses the migration (if needed)
|
||||
func (m *Migration001UpgradePasswordSecurity) Down() error {
|
||||
logging.Error("Password security migration rollback is not supported for security reasons")
|
||||
return errors.New("password security migration rollback is not supported")
|
||||
}
|
||||
|
||||
// RunMigration is a convenience function to run the migration
|
||||
func RunPasswordSecurityMigration(db *gorm.DB) error {
|
||||
migration := NewMigration001UpgradePasswordSecurity(db)
|
||||
return migration.Up()
|
||||
|
||||
@@ -9,21 +9,17 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Migration002MigrateToUUID migrates tables from integer IDs to UUIDs
|
||||
type Migration002MigrateToUUID struct {
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
// NewMigration002MigrateToUUID creates a new UUID migration
|
||||
func NewMigration002MigrateToUUID(db *gorm.DB) *Migration002MigrateToUUID {
|
||||
return &Migration002MigrateToUUID{DB: db}
|
||||
}
|
||||
|
||||
// Up executes the migration
|
||||
func (m *Migration002MigrateToUUID) Up() error {
|
||||
logging.Info("Checking UUID migration...")
|
||||
|
||||
// Check if migration is needed by looking at the servers table structure
|
||||
if !m.needsMigration() {
|
||||
logging.Info("UUID migration not needed - tables already use UUID primary keys")
|
||||
return nil
|
||||
@@ -31,7 +27,6 @@ func (m *Migration002MigrateToUUID) Up() error {
|
||||
|
||||
logging.Info("Starting UUID migration...")
|
||||
|
||||
// Check if migration has already been applied
|
||||
var migrationRecord MigrationRecord
|
||||
err := m.DB.Where("migration_name = ?", "002_migrate_to_uuid").First(&migrationRecord).Error
|
||||
if err == nil {
|
||||
@@ -39,12 +34,10 @@ func (m *Migration002MigrateToUUID) Up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create migration tracking table if it doesn't exist
|
||||
if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil {
|
||||
return fmt.Errorf("failed to create migration tracking table: %v", err)
|
||||
}
|
||||
|
||||
// Execute the UUID migration using the existing migration function
|
||||
logging.Info("Executing UUID migration...")
|
||||
if err := runUUIDMigrationSQL(m.DB); err != nil {
|
||||
return fmt.Errorf("failed to execute UUID migration: %v", err)
|
||||
@@ -54,9 +47,7 @@ func (m *Migration002MigrateToUUID) Up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// needsMigration checks if the UUID migration is needed by examining table structure
|
||||
func (m *Migration002MigrateToUUID) needsMigration() bool {
|
||||
// Check if servers table exists and has integer primary key
|
||||
var result struct {
|
||||
Type string `gorm:"column:type"`
|
||||
}
|
||||
@@ -67,29 +58,22 @@ func (m *Migration002MigrateToUUID) needsMigration() bool {
|
||||
`).Scan(&result).Error
|
||||
|
||||
if err != nil || result.Type == "" {
|
||||
// Table doesn't exist or no primary key found - assume no migration needed
|
||||
return false
|
||||
}
|
||||
|
||||
// If the primary key is INTEGER, we need migration
|
||||
// If it's TEXT (UUID), migration already done
|
||||
return result.Type == "INTEGER" || result.Type == "integer"
|
||||
}
|
||||
|
||||
// Down reverses the migration (not implemented for safety)
|
||||
func (m *Migration002MigrateToUUID) Down() error {
|
||||
logging.Error("UUID migration rollback is not supported for data safety reasons")
|
||||
return fmt.Errorf("UUID migration rollback is not supported")
|
||||
}
|
||||
|
||||
// runUUIDMigrationSQL executes the UUID migration using the SQL file
|
||||
func runUUIDMigrationSQL(db *gorm.DB) error {
|
||||
// Disable foreign key constraints during migration
|
||||
if err := db.Exec("PRAGMA foreign_keys=OFF").Error; err != nil {
|
||||
return fmt.Errorf("failed to disable foreign keys: %v", err)
|
||||
}
|
||||
|
||||
// Start transaction
|
||||
tx := db.Begin()
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("failed to start transaction: %v", tx.Error)
|
||||
@@ -101,25 +85,21 @@ func runUUIDMigrationSQL(db *gorm.DB) error {
|
||||
}
|
||||
}()
|
||||
|
||||
// Read the migration SQL from file
|
||||
sqlPath := filepath.Join("scripts", "migrations", "002_migrate_servers_to_uuid.sql")
|
||||
migrationSQL, err := ioutil.ReadFile(sqlPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read migration SQL file: %v", err)
|
||||
}
|
||||
|
||||
// Execute the migration
|
||||
if err := tx.Exec(string(migrationSQL)).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to execute migration: %v", err)
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return fmt.Errorf("failed to commit migration: %v", err)
|
||||
}
|
||||
|
||||
// Re-enable foreign key constraints
|
||||
if err := db.Exec("PRAGMA foreign_keys=ON").Error; err != nil {
|
||||
return fmt.Errorf("failed to re-enable foreign keys: %v", err)
|
||||
}
|
||||
@@ -127,7 +107,6 @@ func runUUIDMigrationSQL(db *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunUUIDMigration is a convenience function to run the migration
|
||||
func RunUUIDMigration(db *gorm.DB) error {
|
||||
migration := NewMigration002MigrateToUUID(db)
|
||||
return migration.Up()
|
||||
|
||||
@@ -7,21 +7,17 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UpdateStateHistorySessions migrates tables from integer IDs to UUIDs
|
||||
type UpdateStateHistorySessions struct {
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
// NewUpdateStateHistorySessions creates a new UUID migration
|
||||
func NewUpdateStateHistorySessions(db *gorm.DB) *UpdateStateHistorySessions {
|
||||
return &UpdateStateHistorySessions{DB: db}
|
||||
}
|
||||
|
||||
// Up executes the migration
|
||||
func (m *UpdateStateHistorySessions) Up() error {
|
||||
logging.Info("Checking UUID migration...")
|
||||
|
||||
// Check if migration is needed by looking at the servers table structure
|
||||
if !m.needsMigration() {
|
||||
logging.Info("UUID migration not needed - tables already use UUID primary keys")
|
||||
return nil
|
||||
@@ -29,7 +25,6 @@ func (m *UpdateStateHistorySessions) Up() error {
|
||||
|
||||
logging.Info("Starting UUID migration...")
|
||||
|
||||
// Check if migration has already been applied
|
||||
var migrationRecord MigrationRecord
|
||||
err := m.DB.Where("migration_name = ?", "002_migrate_to_uuid").First(&migrationRecord).Error
|
||||
if err == nil {
|
||||
@@ -37,12 +32,10 @@ func (m *UpdateStateHistorySessions) Up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create migration tracking table if it doesn't exist
|
||||
if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil {
|
||||
return fmt.Errorf("failed to create migration tracking table: %v", err)
|
||||
}
|
||||
|
||||
// Execute the UUID migration using the existing migration function
|
||||
logging.Info("Executing UUID migration...")
|
||||
if err := runUUIDMigrationSQL(m.DB); err != nil {
|
||||
return fmt.Errorf("failed to execute UUID migration: %v", err)
|
||||
@@ -52,9 +45,7 @@ func (m *UpdateStateHistorySessions) Up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// needsMigration checks if the UUID migration is needed by examining table structure
|
||||
func (m *UpdateStateHistorySessions) needsMigration() bool {
|
||||
// Check if servers table exists and has integer primary key
|
||||
var result struct {
|
||||
Exists bool `gorm:"column:exists"`
|
||||
}
|
||||
@@ -65,26 +56,21 @@ func (m *UpdateStateHistorySessions) needsMigration() bool {
|
||||
`).Scan(&result).Error
|
||||
|
||||
if err != nil || !result.Exists {
|
||||
// Table doesn't exist or no primary key found - assume no migration needed
|
||||
return false
|
||||
}
|
||||
return result.Exists
|
||||
}
|
||||
|
||||
// Down reverses the migration (not implemented for safety)
|
||||
func (m *UpdateStateHistorySessions) Down() error {
|
||||
logging.Error("UUID migration rollback is not supported for data safety reasons")
|
||||
return fmt.Errorf("UUID migration rollback is not supported")
|
||||
}
|
||||
|
||||
// runUpdateStateHistorySessionsMigration executes the UUID migration using the SQL file
|
||||
func runUpdateStateHistorySessionsMigration(db *gorm.DB) error {
|
||||
// Disable foreign key constraints during migration
|
||||
if err := db.Exec("PRAGMA foreign_keys=OFF").Error; err != nil {
|
||||
return fmt.Errorf("failed to disable foreign keys: %v", err)
|
||||
}
|
||||
|
||||
// Start transaction
|
||||
tx := db.Begin()
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("failed to start transaction: %v", tx.Error)
|
||||
@@ -98,18 +84,15 @@ func runUpdateStateHistorySessionsMigration(db *gorm.DB) error {
|
||||
|
||||
migrationSQL := "UPDATE state_history SET session = upper(substr(session, 1, 1));"
|
||||
|
||||
// Execute the migration
|
||||
if err := tx.Exec(string(migrationSQL)).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to execute migration: %v", err)
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return fmt.Errorf("failed to commit migration: %v", err)
|
||||
}
|
||||
|
||||
// Re-enable foreign key constraints
|
||||
if err := db.Exec("PRAGMA foreign_keys=ON").Error; err != nil {
|
||||
return fmt.Errorf("failed to re-enable foreign keys: %v", err)
|
||||
}
|
||||
@@ -117,7 +100,6 @@ func runUpdateStateHistorySessionsMigration(db *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunUpdateStateHistorySessionsMigration is a convenience function to run the migration
|
||||
func RunUpdateStateHistorySessionsMigration(db *gorm.DB) error {
|
||||
migration := NewUpdateStateHistorySessions(db)
|
||||
return migration.Up()
|
||||
|
||||
@@ -6,20 +6,17 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// StatusCache represents a cached server status with expiration
|
||||
type StatusCache struct {
|
||||
Status ServiceStatus
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// CacheConfig holds configuration for cache behavior
|
||||
type CacheConfig struct {
|
||||
ExpirationTime time.Duration // How long before a cache entry expires
|
||||
ThrottleTime time.Duration // Minimum time between status checks
|
||||
DefaultStatus ServiceStatus // Default status to return when throttled
|
||||
ExpirationTime time.Duration
|
||||
ThrottleTime time.Duration
|
||||
DefaultStatus ServiceStatus
|
||||
}
|
||||
|
||||
// ServerStatusCache manages cached server statuses
|
||||
type ServerStatusCache struct {
|
||||
sync.RWMutex
|
||||
cache map[string]*StatusCache
|
||||
@@ -27,7 +24,6 @@ type ServerStatusCache struct {
|
||||
lastChecked map[string]time.Time
|
||||
}
|
||||
|
||||
// NewServerStatusCache creates a new server status cache
|
||||
func NewServerStatusCache(config CacheConfig) *ServerStatusCache {
|
||||
return &ServerStatusCache{
|
||||
cache: make(map[string]*StatusCache),
|
||||
@@ -36,12 +32,10 @@ func NewServerStatusCache(config CacheConfig) *ServerStatusCache {
|
||||
}
|
||||
}
|
||||
|
||||
// GetStatus retrieves the cached status or indicates if a fresh check is needed
|
||||
func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
// Check if we're being throttled
|
||||
if lastCheck, exists := c.lastChecked[serviceName]; exists {
|
||||
if time.Since(lastCheck) < c.config.ThrottleTime {
|
||||
if cached, ok := c.cache[serviceName]; ok {
|
||||
@@ -51,7 +45,6 @@ func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we have a valid cached entry
|
||||
if cached, ok := c.cache[serviceName]; ok {
|
||||
if time.Since(cached.UpdatedAt) < c.config.ExpirationTime {
|
||||
return cached.Status, false
|
||||
@@ -61,7 +54,6 @@ func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool)
|
||||
return StatusUnknown, true
|
||||
}
|
||||
|
||||
// UpdateStatus updates the cache with a new status
|
||||
func (c *ServerStatusCache) UpdateStatus(serviceName string, status ServiceStatus) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -73,7 +65,6 @@ func (c *ServerStatusCache) UpdateStatus(serviceName string, status ServiceStatu
|
||||
c.lastChecked[serviceName] = time.Now()
|
||||
}
|
||||
|
||||
// InvalidateStatus removes a specific service from the cache
|
||||
func (c *ServerStatusCache) InvalidateStatus(serviceName string) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -82,7 +73,6 @@ func (c *ServerStatusCache) InvalidateStatus(serviceName string) {
|
||||
delete(c.lastChecked, serviceName)
|
||||
}
|
||||
|
||||
// Clear removes all entries from the cache
|
||||
func (c *ServerStatusCache) Clear() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -91,13 +81,11 @@ func (c *ServerStatusCache) Clear() {
|
||||
c.lastChecked = make(map[string]time.Time)
|
||||
}
|
||||
|
||||
// LookupCache provides a generic cache for lookup data
|
||||
type LookupCache struct {
|
||||
sync.RWMutex
|
||||
data map[string]interface{}
|
||||
}
|
||||
|
||||
// NewLookupCache creates a new lookup cache
|
||||
func NewLookupCache() *LookupCache {
|
||||
logging.Debug("Initializing new LookupCache")
|
||||
return &LookupCache{
|
||||
@@ -105,7 +93,6 @@ func NewLookupCache() *LookupCache {
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a cached value by key
|
||||
func (c *LookupCache) Get(key string) (interface{}, bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
@@ -119,7 +106,6 @@ func (c *LookupCache) Get(key string) (interface{}, bool) {
|
||||
return value, exists
|
||||
}
|
||||
|
||||
// Set stores a value in the cache
|
||||
func (c *LookupCache) Set(key string, value interface{}) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -128,7 +114,6 @@ func (c *LookupCache) Set(key string, value interface{}) {
|
||||
logging.Debug("Cache SET for key: %s", key)
|
||||
}
|
||||
|
||||
// Clear removes all entries from the cache
|
||||
func (c *LookupCache) Clear() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -137,13 +122,11 @@ func (c *LookupCache) Clear() {
|
||||
logging.Debug("Cache CLEARED")
|
||||
}
|
||||
|
||||
// ConfigEntry represents a cached configuration entry with its update time
|
||||
type ConfigEntry[T any] struct {
|
||||
Data T
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// getConfigFromCache is a generic helper function to retrieve cached configs
|
||||
func getConfigFromCache[T any](cache map[string]*ConfigEntry[T], serverID string, expirationTime time.Duration) (*T, bool) {
|
||||
if entry, ok := cache[serverID]; ok {
|
||||
if time.Since(entry.UpdatedAt) < expirationTime {
|
||||
@@ -157,7 +140,6 @@ func getConfigFromCache[T any](cache map[string]*ConfigEntry[T], serverID string
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// updateConfigInCache is a generic helper function to update cached configs
|
||||
func updateConfigInCache[T any](cache map[string]*ConfigEntry[T], serverID string, data T) {
|
||||
cache[serverID] = &ConfigEntry[T]{
|
||||
Data: data,
|
||||
@@ -166,7 +148,6 @@ func updateConfigInCache[T any](cache map[string]*ConfigEntry[T], serverID strin
|
||||
logging.Debug("Config cache SET for server ID: %s", serverID)
|
||||
}
|
||||
|
||||
// ServerConfigCache manages cached server configurations
|
||||
type ServerConfigCache struct {
|
||||
sync.RWMutex
|
||||
configuration map[string]*ConfigEntry[Configuration]
|
||||
@@ -177,7 +158,6 @@ type ServerConfigCache struct {
|
||||
config CacheConfig
|
||||
}
|
||||
|
||||
// NewServerConfigCache creates a new server configuration cache
|
||||
func NewServerConfigCache(config CacheConfig) *ServerConfigCache {
|
||||
logging.Debug("Initializing new ServerConfigCache with expiration time: %v, throttle time: %v", config.ExpirationTime, config.ThrottleTime)
|
||||
return &ServerConfigCache{
|
||||
@@ -190,7 +170,6 @@ func NewServerConfigCache(config CacheConfig) *ServerConfigCache {
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfiguration retrieves a cached configuration
|
||||
func (c *ServerConfigCache) GetConfiguration(serverID string) (*Configuration, bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
@@ -198,7 +177,6 @@ func (c *ServerConfigCache) GetConfiguration(serverID string) (*Configuration, b
|
||||
return getConfigFromCache(c.configuration, serverID, c.config.ExpirationTime)
|
||||
}
|
||||
|
||||
// GetAssistRules retrieves cached assist rules
|
||||
func (c *ServerConfigCache) GetAssistRules(serverID string) (*AssistRules, bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
@@ -206,7 +184,6 @@ func (c *ServerConfigCache) GetAssistRules(serverID string) (*AssistRules, bool)
|
||||
return getConfigFromCache(c.assistRules, serverID, c.config.ExpirationTime)
|
||||
}
|
||||
|
||||
// GetEvent retrieves cached event configuration
|
||||
func (c *ServerConfigCache) GetEvent(serverID string) (*EventConfig, bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
@@ -214,7 +191,6 @@ func (c *ServerConfigCache) GetEvent(serverID string) (*EventConfig, bool) {
|
||||
return getConfigFromCache(c.event, serverID, c.config.ExpirationTime)
|
||||
}
|
||||
|
||||
// GetEventRules retrieves cached event rules
|
||||
func (c *ServerConfigCache) GetEventRules(serverID string) (*EventRules, bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
@@ -222,7 +198,6 @@ func (c *ServerConfigCache) GetEventRules(serverID string) (*EventRules, bool) {
|
||||
return getConfigFromCache(c.eventRules, serverID, c.config.ExpirationTime)
|
||||
}
|
||||
|
||||
// GetSettings retrieves cached server settings
|
||||
func (c *ServerConfigCache) GetSettings(serverID string) (*ServerSettings, bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
@@ -230,7 +205,6 @@ func (c *ServerConfigCache) GetSettings(serverID string) (*ServerSettings, bool)
|
||||
return getConfigFromCache(c.settings, serverID, c.config.ExpirationTime)
|
||||
}
|
||||
|
||||
// UpdateConfiguration updates the configuration cache
|
||||
func (c *ServerConfigCache) UpdateConfiguration(serverID string, config Configuration) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -238,7 +212,6 @@ func (c *ServerConfigCache) UpdateConfiguration(serverID string, config Configur
|
||||
updateConfigInCache(c.configuration, serverID, config)
|
||||
}
|
||||
|
||||
// UpdateAssistRules updates the assist rules cache
|
||||
func (c *ServerConfigCache) UpdateAssistRules(serverID string, rules AssistRules) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -246,7 +219,6 @@ func (c *ServerConfigCache) UpdateAssistRules(serverID string, rules AssistRules
|
||||
updateConfigInCache(c.assistRules, serverID, rules)
|
||||
}
|
||||
|
||||
// UpdateEvent updates the event configuration cache
|
||||
func (c *ServerConfigCache) UpdateEvent(serverID string, event EventConfig) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -254,7 +226,6 @@ func (c *ServerConfigCache) UpdateEvent(serverID string, event EventConfig) {
|
||||
updateConfigInCache(c.event, serverID, event)
|
||||
}
|
||||
|
||||
// UpdateEventRules updates the event rules cache
|
||||
func (c *ServerConfigCache) UpdateEventRules(serverID string, rules EventRules) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -262,7 +233,6 @@ func (c *ServerConfigCache) UpdateEventRules(serverID string, rules EventRules)
|
||||
updateConfigInCache(c.eventRules, serverID, rules)
|
||||
}
|
||||
|
||||
// UpdateSettings updates the server settings cache
|
||||
func (c *ServerConfigCache) UpdateSettings(serverID string, settings ServerSettings) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -270,7 +240,6 @@ func (c *ServerConfigCache) UpdateSettings(serverID string, settings ServerSetti
|
||||
updateConfigInCache(c.settings, serverID, settings)
|
||||
}
|
||||
|
||||
// InvalidateServerCache removes all cached configurations for a specific server
|
||||
func (c *ServerConfigCache) InvalidateServerCache(serverID string) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -283,7 +252,6 @@ func (c *ServerConfigCache) InvalidateServerCache(serverID string) {
|
||||
delete(c.settings, serverID)
|
||||
}
|
||||
|
||||
// Clear removes all entries from the cache
|
||||
func (c *ServerConfigCache) Clear() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
@@ -13,17 +13,15 @@ import (
|
||||
type IntString int
|
||||
type IntBool int
|
||||
|
||||
// Config tracks configuration modifications
|
||||
type Config struct {
|
||||
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
|
||||
ServerID uuid.UUID `json:"serverId" gorm:"not null;type:uuid"`
|
||||
ConfigFile string `json:"configFile" gorm:"not null"` // e.g. "settings.json"
|
||||
ConfigFile string `json:"configFile" gorm:"not null"`
|
||||
OldConfig string `json:"oldConfig" gorm:"type:text"`
|
||||
NewConfig string `json:"newConfig" gorm:"type:text"`
|
||||
ChangedAt time.Time `json:"changedAt" gorm:"default:CURRENT_TIMESTAMP"`
|
||||
}
|
||||
|
||||
// BeforeCreate is a GORM hook that runs before creating new config entries
|
||||
func (c *Config) BeforeCreate(tx *gorm.DB) error {
|
||||
if c.ID == uuid.Nil {
|
||||
c.ID = uuid.New()
|
||||
@@ -121,8 +119,6 @@ type Configuration struct {
|
||||
ConfigVersion IntString `json:"configVersion"`
|
||||
}
|
||||
|
||||
// Known configuration keys
|
||||
|
||||
func (i *IntBool) UnmarshalJSON(b []byte) error {
|
||||
var str int
|
||||
if err := json.Unmarshal(b, &str); err == nil && str <= 1 {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// BaseFilter contains common filter fields that can be embedded in other filters
|
||||
type BaseFilter struct {
|
||||
Page int `query:"page"`
|
||||
PageSize int `query:"page_size"`
|
||||
@@ -15,18 +14,15 @@ type BaseFilter struct {
|
||||
SortDesc bool `query:"sort_desc"`
|
||||
}
|
||||
|
||||
// DateRangeFilter adds date range filtering capabilities
|
||||
type DateRangeFilter struct {
|
||||
StartDate time.Time `query:"start_date" time_format:"2006-01-02T15:04:05Z07:00"`
|
||||
EndDate time.Time `query:"end_date" time_format:"2006-01-02T15:04:05Z07:00"`
|
||||
}
|
||||
|
||||
// ServerBasedFilter adds server ID filtering capability
|
||||
type ServerBasedFilter struct {
|
||||
ServerID string `param:"id"`
|
||||
}
|
||||
|
||||
// ConfigFilter defines filtering options for Config queries
|
||||
type ConfigFilter struct {
|
||||
BaseFilter
|
||||
ServerBasedFilter
|
||||
@@ -34,13 +30,11 @@ type ConfigFilter struct {
|
||||
ChangedAt time.Time `query:"changed_at" time_format:"2006-01-02T15:04:05Z07:00"`
|
||||
}
|
||||
|
||||
// ApiFilter defines filtering options for Api queries
|
||||
type ServiceControlFilter struct {
|
||||
BaseFilter
|
||||
ServiceControl string `query:"serviceControl"`
|
||||
}
|
||||
|
||||
// MembershipFilter defines filtering options for User queries
|
||||
type MembershipFilter struct {
|
||||
BaseFilter
|
||||
Username string `query:"username"`
|
||||
@@ -48,36 +42,32 @@ type MembershipFilter struct {
|
||||
RoleID string `query:"role_id"`
|
||||
}
|
||||
|
||||
// Pagination returns the offset and limit for database queries
|
||||
func (f *BaseFilter) Pagination() (offset, limit int) {
|
||||
if f.Page < 1 {
|
||||
f.Page = 1
|
||||
}
|
||||
if f.PageSize < 1 {
|
||||
f.PageSize = 10 // Default page size
|
||||
f.PageSize = 10
|
||||
}
|
||||
offset = (f.Page - 1) * f.PageSize
|
||||
limit = f.PageSize
|
||||
return
|
||||
}
|
||||
|
||||
// GetSorting returns the sort field and direction for database queries
|
||||
func (f *BaseFilter) GetSorting() (field string, desc bool) {
|
||||
if f.SortBy == "" {
|
||||
return "id", false // Default sorting
|
||||
return "id", false
|
||||
}
|
||||
return f.SortBy, f.SortDesc
|
||||
}
|
||||
|
||||
// IsDateRangeValid checks if both dates are set and start date is before end date
|
||||
func (f *DateRangeFilter) IsDateRangeValid() bool {
|
||||
if f.StartDate.IsZero() || f.EndDate.IsZero() {
|
||||
return true // If either date is not set, consider it valid
|
||||
return true
|
||||
}
|
||||
return f.StartDate.Before(f.EndDate)
|
||||
}
|
||||
|
||||
// ApplyFilter applies the membership filter to a GORM query
|
||||
func (f *MembershipFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
|
||||
if f.Username != "" {
|
||||
query = query.Where("username LIKE ?", "%"+f.Username+"%")
|
||||
@@ -93,12 +83,10 @@ func (f *MembershipFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
|
||||
return query
|
||||
}
|
||||
|
||||
// Pagination returns the offset and limit for database queries
|
||||
func (f *MembershipFilter) Pagination() (offset, limit int) {
|
||||
return f.BaseFilter.Pagination()
|
||||
}
|
||||
|
||||
// GetSorting returns the sort field and direction for database queries
|
||||
func (f *MembershipFilter) GetSorting() (field string, desc bool) {
|
||||
return f.BaseFilter.GetSorting()
|
||||
}
|
||||
|
||||
@@ -1,31 +1,26 @@
|
||||
package model
|
||||
|
||||
// Track represents a track and its capacity
|
||||
type Track struct {
|
||||
Name string `json:"track" gorm:"primaryKey;size:50"`
|
||||
UniquePitBoxes int `json:"unique_pit_boxes"`
|
||||
PrivateServerSlots int `json:"private_server_slots"`
|
||||
}
|
||||
|
||||
// CarModel represents a car model mapping
|
||||
type CarModel struct {
|
||||
Value int `json:"value" gorm:"primaryKey"`
|
||||
CarModel string `json:"car_model"`
|
||||
}
|
||||
|
||||
// DriverCategory represents driver skill categories
|
||||
type DriverCategory struct {
|
||||
Value int `json:"value" gorm:"primaryKey"`
|
||||
Category string `json:"category"`
|
||||
}
|
||||
|
||||
// CupCategory represents championship cup categories
|
||||
type CupCategory struct {
|
||||
Value int `json:"value" gorm:"primaryKey"`
|
||||
Category string `json:"category"`
|
||||
}
|
||||
|
||||
// SessionType represents session types
|
||||
type SessionType struct {
|
||||
Value int `json:"value" gorm:"primaryKey"`
|
||||
SessionType string `json:"session_type"`
|
||||
|
||||
@@ -32,8 +32,6 @@ type BaseModel struct {
|
||||
DateUpdated time.Time `json:"dateUpdated"`
|
||||
}
|
||||
|
||||
// Init
|
||||
// Initializes base model with DateCreated, DateUpdated, and Id values.
|
||||
func (cm *BaseModel) Init() {
|
||||
date := time.Now()
|
||||
cm.Id = uuid.NewString()
|
||||
|
||||
@@ -5,15 +5,13 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Permission represents an action that can be performed in the system.
|
||||
type Permission struct {
|
||||
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
|
||||
Name string `json:"name" gorm:"unique_index;not null"`
|
||||
}
|
||||
|
||||
// BeforeCreate is a GORM hook that runs before creating new credentials
|
||||
func (s *Permission) BeforeCreate(tx *gorm.DB) error {
|
||||
s.ID = uuid.New()
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package model
|
||||
|
||||
// Permission constants
|
||||
const (
|
||||
ServerView = "server.view"
|
||||
ServerCreate = "server.create"
|
||||
@@ -27,7 +26,6 @@ const (
|
||||
MembershipEdit = "membership.edit"
|
||||
)
|
||||
|
||||
// AllPermissions returns a slice of all permission strings.
|
||||
func AllPermissions() []string {
|
||||
return []string{
|
||||
ServerView,
|
||||
|
||||
@@ -5,16 +5,14 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Role represents a user role in the system.
|
||||
type Role struct {
|
||||
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
|
||||
Name string `json:"name" gorm:"unique_index;not null"`
|
||||
Permissions []Permission `json:"permissions" gorm:"many2many:role_permissions;"`
|
||||
}
|
||||
|
||||
// BeforeCreate is a GORM hook that runs before creating new credentials
|
||||
func (s *Role) BeforeCreate(tx *gorm.DB) error {
|
||||
s.ID = uuid.New()
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ const (
|
||||
ServiceNamePrefix = "ACC-Server"
|
||||
)
|
||||
|
||||
// Server represents an ACC server instance
|
||||
type ServerAPI struct {
|
||||
Name string `json:"name"`
|
||||
Status ServiceStatus `json:"status"`
|
||||
@@ -35,28 +34,27 @@ func (s *Server) ToServerAPI() *ServerAPI {
|
||||
}
|
||||
}
|
||||
|
||||
// Server represents an ACC server instance
|
||||
type Server struct {
|
||||
ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Status ServiceStatus `json:"status" gorm:"-"`
|
||||
IP string `gorm:"not null" json:"-"`
|
||||
Port int `gorm:"not null" json:"-"`
|
||||
Path string `gorm:"not null" json:"path"` // e.g. "/acc/servers/server1/"
|
||||
ServiceName string `gorm:"not null" json:"serviceName"` // Windows service name
|
||||
Path string `gorm:"not null" json:"path"`
|
||||
ServiceName string `gorm:"not null" json:"serviceName"`
|
||||
State *ServerState `gorm:"-" json:"state"`
|
||||
DateCreated time.Time `json:"dateCreated"`
|
||||
FromSteamCMD bool `gorm:"not null; default:true" json:"-"`
|
||||
}
|
||||
|
||||
type PlayerState struct {
|
||||
CarID int // Car ID in broadcast packets
|
||||
DriverName string // Optional: pulled from registration packet
|
||||
CarID int
|
||||
DriverName string
|
||||
TeamName string
|
||||
CarModel string
|
||||
CurrentLap int
|
||||
LastLapTime int // in milliseconds
|
||||
BestLapTime int // in milliseconds
|
||||
LastLapTime int
|
||||
BestLapTime int
|
||||
Position int
|
||||
ConnectedAt time.Time
|
||||
DisconnectedAt *time.Time
|
||||
@@ -67,8 +65,6 @@ type State struct {
|
||||
Session string `json:"session"`
|
||||
SessionStart time.Time `json:"sessionStart"`
|
||||
PlayerCount int `json:"playerCount"`
|
||||
// Players map[int]*PlayerState
|
||||
// etc.
|
||||
}
|
||||
|
||||
type ServerState struct {
|
||||
@@ -79,11 +75,8 @@ type ServerState struct {
|
||||
Track string `json:"track"`
|
||||
MaxConnections int `json:"maxConnections"`
|
||||
SessionDurationMinutes int `json:"sessionDurationMinutes"`
|
||||
// Players map[int]*PlayerState
|
||||
// etc.
|
||||
}
|
||||
|
||||
// ServerFilter defines filtering options for Server queries
|
||||
type ServerFilter struct {
|
||||
BaseFilter
|
||||
ServerBasedFilter
|
||||
@@ -92,9 +85,7 @@ type ServerFilter struct {
|
||||
Status string `query:"status"`
|
||||
}
|
||||
|
||||
// ApplyFilter implements the Filterable interface
|
||||
func (f *ServerFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
|
||||
// Apply server filter
|
||||
if f.ServerID != "" {
|
||||
if serverUUID, err := uuid.Parse(f.ServerID); err == nil {
|
||||
query = query.Where("id = ?", serverUUID)
|
||||
@@ -110,16 +101,13 @@ func (s *Server) GenerateUUID() {
|
||||
}
|
||||
}
|
||||
|
||||
// BeforeCreate is a GORM hook that runs before creating a new server
|
||||
func (s *Server) BeforeCreate(tx *gorm.DB) error {
|
||||
if s.Name == "" {
|
||||
return errors.New("server name is required")
|
||||
}
|
||||
|
||||
// Generate UUID if not set
|
||||
s.GenerateUUID()
|
||||
|
||||
// Generate service name and config path if not set
|
||||
if s.ServiceName == "" {
|
||||
s.ServiceName = s.GenerateServiceName()
|
||||
}
|
||||
@@ -127,7 +115,6 @@ func (s *Server) BeforeCreate(tx *gorm.DB) error {
|
||||
s.Path = s.GenerateServerPath(BaseServerPath)
|
||||
}
|
||||
|
||||
// Set creation date if not set
|
||||
if s.DateCreated.IsZero() {
|
||||
s.DateCreated = time.Now().UTC()
|
||||
}
|
||||
@@ -135,19 +122,14 @@ func (s *Server) BeforeCreate(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateServiceName creates a unique service name based on the server name
|
||||
func (s *Server) GenerateServiceName() string {
|
||||
// If ID is set, use it
|
||||
if s.ID != uuid.Nil {
|
||||
return fmt.Sprintf("%s-%s", ServiceNamePrefix, s.ID.String()[:8])
|
||||
}
|
||||
// Otherwise use a timestamp-based unique identifier
|
||||
return fmt.Sprintf("%s-%d", ServiceNamePrefix, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// GenerateServerPath creates the config path based on the service name
|
||||
func (s *Server) GenerateServerPath(steamCMDPath string) string {
|
||||
// Ensure service name is set
|
||||
if s.ServiceName == "" {
|
||||
s.ServiceName = s.GenerateServiceName()
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ const (
|
||||
StatusRunning
|
||||
)
|
||||
|
||||
// String converts the ServiceStatus to its string representation
|
||||
func (s ServiceStatus) String() string {
|
||||
switch s {
|
||||
case StatusRunning:
|
||||
@@ -35,7 +34,6 @@ func (s ServiceStatus) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
// ParseServiceStatus converts a string to ServiceStatus
|
||||
func ParseServiceStatus(s string) ServiceStatus {
|
||||
switch s {
|
||||
case "SERVICE_RUNNING":
|
||||
@@ -53,31 +51,24 @@ func ParseServiceStatus(s string) ServiceStatus {
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler interface
|
||||
func (s ServiceStatus) MarshalJSON() ([]byte, error) {
|
||||
// Return the numeric value instead of string
|
||||
return []byte(strconv.Itoa(int(s))), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler interface
|
||||
func (s *ServiceStatus) UnmarshalJSON(data []byte) error {
|
||||
// Try to parse as number first
|
||||
if i, err := strconv.Atoi(string(data)); err == nil {
|
||||
*s = ServiceStatus(i)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fallback to string parsing for backward compatibility
|
||||
str := string(data)
|
||||
if len(str) >= 2 {
|
||||
// Remove quotes if present
|
||||
str = str[1 : len(str)-1]
|
||||
}
|
||||
*s = ParseServiceStatus(str)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface
|
||||
func (s *ServiceStatus) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*s = StatusUnknown
|
||||
@@ -99,7 +90,6 @@ func (s *ServiceStatus) Scan(value interface{}) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface
|
||||
func (s ServiceStatus) Value() (driver.Value, error) {
|
||||
return s.String(), nil
|
||||
}
|
||||
|
||||
@@ -10,27 +10,22 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// StateHistoryFilter combines common filter capabilities
|
||||
type StateHistoryFilter struct {
|
||||
ServerBasedFilter // Adds server ID from path parameter
|
||||
DateRangeFilter // Adds date range filtering
|
||||
ServerBasedFilter
|
||||
DateRangeFilter
|
||||
|
||||
// Additional fields specific to state history
|
||||
Session TrackSession `query:"session"`
|
||||
MinPlayers *int `query:"min_players"`
|
||||
MaxPlayers *int `query:"max_players"`
|
||||
}
|
||||
|
||||
// ApplyFilter implements the Filterable interface
|
||||
func (f *StateHistoryFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
|
||||
// Apply server filter
|
||||
if f.ServerID != "" {
|
||||
if serverUUID, err := uuid.Parse(f.ServerID); err == nil {
|
||||
query = query.Where("server_id = ?", serverUUID)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply date range filter if set
|
||||
timeZero := time.Time{}
|
||||
if f.StartDate != timeZero {
|
||||
query = query.Where("date_created >= ?", f.StartDate)
|
||||
@@ -39,12 +34,10 @@ func (f *StateHistoryFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
|
||||
query = query.Where("date_created <= ?", f.EndDate)
|
||||
}
|
||||
|
||||
// Apply session filter if set
|
||||
if f.Session != "" {
|
||||
query = query.Where("session = ?", f.Session)
|
||||
}
|
||||
|
||||
// Apply player count filters if set
|
||||
if f.MinPlayers != nil {
|
||||
query = query.Where("player_count >= ?", *f.MinPlayers)
|
||||
}
|
||||
@@ -114,10 +107,9 @@ type StateHistory struct {
|
||||
DateCreated time.Time `json:"dateCreated"`
|
||||
SessionStart time.Time `json:"sessionStart"`
|
||||
SessionDurationMinutes int `json:"sessionDurationMinutes"`
|
||||
SessionID uuid.UUID `json:"sessionId" gorm:"not null;type:uuid"` // Unique identifier for each session/event
|
||||
SessionID uuid.UUID `json:"sessionId" gorm:"not null;type:uuid"`
|
||||
}
|
||||
|
||||
// BeforeCreate is a GORM hook that runs before creating new state history entries
|
||||
func (sh *StateHistory) BeforeCreate(tx *gorm.DB) error {
|
||||
if sh.ID == uuid.Nil {
|
||||
sh.ID = uuid.New()
|
||||
|
||||
@@ -21,7 +21,7 @@ type StateHistoryStats struct {
|
||||
AveragePlayers float64 `json:"averagePlayers"`
|
||||
PeakPlayers int `json:"peakPlayers"`
|
||||
TotalSessions int `json:"totalSessions"`
|
||||
TotalPlaytime int `json:"totalPlaytime" gorm:"-"` // in minutes
|
||||
TotalPlaytime int `json:"totalPlaytime" gorm:"-"`
|
||||
PlayerCountOverTime []PlayerCountPoint `json:"playerCountOverTime" gorm:"-"`
|
||||
SessionTypes []SessionCount `json:"sessionTypes" gorm:"-"`
|
||||
DailyActivity []DailyActivity `json:"dailyActivity" gorm:"-"`
|
||||
|
||||
@@ -16,21 +16,18 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SteamCredentials represents stored Steam login credentials
|
||||
type SteamCredentials struct {
|
||||
ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"`
|
||||
Username string `gorm:"not null" json:"username"`
|
||||
Password string `gorm:"not null" json:"-"` // Encrypted, not exposed in JSON
|
||||
Password string `gorm:"not null" json:"-"`
|
||||
DateCreated time.Time `json:"dateCreated"`
|
||||
LastUpdated time.Time `json:"lastUpdated"`
|
||||
}
|
||||
|
||||
// TableName specifies the table name for GORM
|
||||
func (SteamCredentials) TableName() string {
|
||||
return "steam_credentials"
|
||||
}
|
||||
|
||||
// BeforeCreate is a GORM hook that runs before creating new credentials
|
||||
func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error {
|
||||
if s.ID == uuid.Nil {
|
||||
s.ID = uuid.New()
|
||||
@@ -42,7 +39,6 @@ func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error {
|
||||
}
|
||||
s.LastUpdated = now
|
||||
|
||||
// Encrypt password before saving
|
||||
encrypted, err := EncryptPassword(s.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -52,11 +48,9 @@ func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeforeUpdate is a GORM hook that runs before updating credentials
|
||||
func (s *SteamCredentials) BeforeUpdate(tx *gorm.DB) error {
|
||||
s.LastUpdated = time.Now().UTC()
|
||||
|
||||
// Only encrypt if password field is being updated
|
||||
if tx.Statement.Changed("Password") {
|
||||
encrypted, err := EncryptPassword(s.Password)
|
||||
if err != nil {
|
||||
@@ -68,9 +62,7 @@ func (s *SteamCredentials) BeforeUpdate(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind is a GORM hook that runs after fetching credentials
|
||||
func (s *SteamCredentials) AfterFind(tx *gorm.DB) error {
|
||||
// Decrypt password after fetching
|
||||
if s.Password != "" {
|
||||
decrypted, err := DecryptPassword(s.Password)
|
||||
if err != nil {
|
||||
@@ -81,18 +73,15 @@ func (s *SteamCredentials) AfterFind(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks if the credentials are valid with enhanced security checks
|
||||
func (s *SteamCredentials) Validate() error {
|
||||
if s.Username == "" {
|
||||
return errors.New("username is required")
|
||||
}
|
||||
|
||||
// Enhanced username validation
|
||||
if len(s.Username) < 3 || len(s.Username) > 64 {
|
||||
return errors.New("username must be between 3 and 64 characters")
|
||||
}
|
||||
|
||||
// Check for valid characters in username (alphanumeric, underscore, hyphen)
|
||||
if matched, _ := regexp.MatchString(`^[a-zA-Z0-9_-]+$`, s.Username); !matched {
|
||||
return errors.New("username contains invalid characters")
|
||||
}
|
||||
@@ -101,7 +90,6 @@ func (s *SteamCredentials) Validate() error {
|
||||
return errors.New("password is required")
|
||||
}
|
||||
|
||||
// Basic password validation
|
||||
if len(s.Password) < 6 {
|
||||
return errors.New("password must be at least 6 characters long")
|
||||
}
|
||||
@@ -110,7 +98,6 @@ func (s *SteamCredentials) Validate() error {
|
||||
return errors.New("password is too long")
|
||||
}
|
||||
|
||||
// Check for obvious weak passwords
|
||||
weakPasswords := []string{"password", "123456", "steam", "admin", "user"}
|
||||
lowerPass := strings.ToLower(s.Password)
|
||||
for _, weak := range weakPasswords {
|
||||
@@ -122,8 +109,6 @@ func (s *SteamCredentials) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEncryptionKey returns the encryption key from config.
|
||||
// The key is loaded from the ENCRYPTION_KEY environment variable.
|
||||
func GetEncryptionKey() []byte {
|
||||
key := []byte(configs.EncryptionKey)
|
||||
if len(key) != 32 {
|
||||
@@ -132,7 +117,6 @@ func GetEncryptionKey() []byte {
|
||||
return key
|
||||
}
|
||||
|
||||
// EncryptPassword encrypts a password using AES-256-GCM with enhanced security
|
||||
func EncryptPassword(password string) (string, error) {
|
||||
if password == "" {
|
||||
return "", errors.New("password cannot be empty")
|
||||
@@ -148,33 +132,27 @@ func EncryptPassword(password string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create a new GCM cipher
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create a cryptographically secure nonce
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Encrypt the password with authenticated encryption
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(password), nil)
|
||||
|
||||
// Return base64 encoded encrypted password
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// DecryptPassword decrypts an encrypted password with enhanced validation
|
||||
func DecryptPassword(encryptedPassword string) (string, error) {
|
||||
if encryptedPassword == "" {
|
||||
return "", errors.New("encrypted password cannot be empty")
|
||||
}
|
||||
|
||||
// Validate base64 format
|
||||
if len(encryptedPassword) < 24 { // Minimum reasonable length
|
||||
if len(encryptedPassword) < 24 {
|
||||
return "", errors.New("invalid encrypted password format")
|
||||
}
|
||||
|
||||
@@ -184,13 +162,11 @@ func DecryptPassword(encryptedPassword string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create a new GCM cipher
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Decode base64 encoded password
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword)
|
||||
if err != nil {
|
||||
return "", errors.New("invalid base64 encoding")
|
||||
@@ -207,7 +183,6 @@ func DecryptPassword(encryptedPassword string) (string, error) {
|
||||
return "", errors.New("decryption failed - invalid ciphertext or key")
|
||||
}
|
||||
|
||||
// Validate decrypted content
|
||||
decrypted := string(plaintext)
|
||||
if len(decrypted) == 0 || len(decrypted) > 1024 {
|
||||
return "", errors.New("invalid decrypted password")
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// User represents a user account in the system.
|
||||
type User struct {
|
||||
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
|
||||
Username string `json:"username" gorm:"unique_index;not null"`
|
||||
@@ -17,16 +16,13 @@ type User struct {
|
||||
Role Role `json:"role"`
|
||||
}
|
||||
|
||||
// BeforeCreate is a GORM hook that runs before creating new users
|
||||
func (s *User) BeforeCreate(tx *gorm.DB) error {
|
||||
s.ID = uuid.New()
|
||||
|
||||
// Validate password strength
|
||||
if err := password.ValidatePasswordStrength(s.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Hash password before saving
|
||||
hashed, err := password.HashPassword(s.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -36,11 +32,8 @@ func (s *User) BeforeCreate(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeforeUpdate is a GORM hook that runs before updating users
|
||||
func (s *User) BeforeUpdate(tx *gorm.DB) error {
|
||||
// Only hash if password field is being updated
|
||||
if tx.Statement.Changed("Password") {
|
||||
// Validate password strength
|
||||
if err := password.ValidatePasswordStrength(s.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -55,14 +48,10 @@ func (s *User) BeforeUpdate(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind is a GORM hook that runs after fetching users
|
||||
func (s *User) AfterFind(tx *gorm.DB) error {
|
||||
// Password remains hashed - never decrypt
|
||||
// This hook is kept for potential future use
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks if the user data is valid
|
||||
func (s *User) Validate() error {
|
||||
if s.Username == "" {
|
||||
return errors.New("username is required")
|
||||
@@ -73,7 +62,6 @@ func (s *User) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyPassword verifies a plain text password against the stored hash
|
||||
func (s *User) VerifyPassword(plainPassword string) error {
|
||||
return password.VerifyPassword(s.Password, plainPassword)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ServerCreationStep represents the steps in server creation process
|
||||
type ServerCreationStep string
|
||||
|
||||
const (
|
||||
@@ -18,7 +17,6 @@ const (
|
||||
StepCompleted ServerCreationStep = "completed"
|
||||
)
|
||||
|
||||
// StepStatus represents the status of a step
|
||||
type StepStatus string
|
||||
|
||||
const (
|
||||
@@ -28,7 +26,6 @@ const (
|
||||
StatusFailed StepStatus = "failed"
|
||||
)
|
||||
|
||||
// WebSocketMessageType represents different types of WebSocket messages
|
||||
type WebSocketMessageType string
|
||||
|
||||
const (
|
||||
@@ -38,7 +35,6 @@ const (
|
||||
MessageTypeComplete WebSocketMessageType = "complete"
|
||||
)
|
||||
|
||||
// WebSocketMessage is the base structure for all WebSocket messages
|
||||
type WebSocketMessage struct {
|
||||
Type WebSocketMessageType `json:"type"`
|
||||
ServerID *uuid.UUID `json:"server_id,omitempty"`
|
||||
@@ -46,7 +42,6 @@ type WebSocketMessage struct {
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
// StepMessage represents a step update message
|
||||
type StepMessage struct {
|
||||
Step ServerCreationStep `json:"step"`
|
||||
Status StepStatus `json:"status"`
|
||||
@@ -54,26 +49,22 @@ type StepMessage struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// SteamOutputMessage represents SteamCMD output
|
||||
type SteamOutputMessage struct {
|
||||
Output string `json:"output"`
|
||||
IsError bool `json:"is_error"`
|
||||
}
|
||||
|
||||
// ErrorMessage represents an error message
|
||||
type ErrorMessage struct {
|
||||
Error string `json:"error"`
|
||||
Details string `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// CompleteMessage represents completion message
|
||||
type CompleteMessage struct {
|
||||
ServerID uuid.UUID `json:"server_id"`
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// GetStepDescription returns a human-readable description for each step
|
||||
func GetStepDescription(step ServerCreationStep) string {
|
||||
descriptions := map[ServerCreationStep]string{
|
||||
StepValidation: "Validating server configuration",
|
||||
|
||||
@@ -7,13 +7,11 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// BaseRepository provides generic CRUD operations for any model
|
||||
type BaseRepository[T any, F any] struct {
|
||||
db *gorm.DB
|
||||
modelType T
|
||||
}
|
||||
|
||||
// NewBaseRepository creates a new base repository for the given model type
|
||||
func NewBaseRepository[T any, F any](db *gorm.DB, model T) *BaseRepository[T, F] {
|
||||
return &BaseRepository[T, F]{
|
||||
db: db,
|
||||
@@ -21,23 +19,19 @@ func NewBaseRepository[T any, F any](db *gorm.DB, model T) *BaseRepository[T, F]
|
||||
}
|
||||
}
|
||||
|
||||
// GetAll retrieves all records based on the filter
|
||||
func (r *BaseRepository[T, F]) GetAll(ctx context.Context, filter *F) (*[]T, error) {
|
||||
result := new([]T)
|
||||
query := r.db.WithContext(ctx).Model(&r.modelType)
|
||||
|
||||
// Apply filter conditions if filter implements Filterable
|
||||
if filterable, ok := any(filter).(Filterable); ok {
|
||||
query = filterable.ApplyFilter(query)
|
||||
}
|
||||
|
||||
// Apply pagination if filter implements Pageable
|
||||
if pageable, ok := any(filter).(Pageable); ok {
|
||||
offset, limit := pageable.Pagination()
|
||||
query = query.Offset(offset).Limit(limit)
|
||||
}
|
||||
|
||||
// Apply sorting if filter implements Sortable
|
||||
if sortable, ok := any(filter).(Sortable); ok {
|
||||
field, desc := sortable.GetSorting()
|
||||
if desc {
|
||||
@@ -54,7 +48,6 @@ func (r *BaseRepository[T, F]) GetAll(ctx context.Context, filter *F) (*[]T, err
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByID retrieves a single record by ID
|
||||
func (r *BaseRepository[T, F]) GetByID(ctx context.Context, id interface{}) (*T, error) {
|
||||
result := new(T)
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(result).Error; err != nil {
|
||||
@@ -66,7 +59,6 @@ func (r *BaseRepository[T, F]) GetByID(ctx context.Context, id interface{}) (*T,
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Insert creates a new record
|
||||
func (r *BaseRepository[T, F]) Insert(ctx context.Context, model *T) error {
|
||||
if err := r.db.WithContext(ctx).Create(model).Error; err != nil {
|
||||
return fmt.Errorf("error creating record: %w", err)
|
||||
@@ -74,7 +66,6 @@ func (r *BaseRepository[T, F]) Insert(ctx context.Context, model *T) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update modifies an existing record
|
||||
func (r *BaseRepository[T, F]) Update(ctx context.Context, model *T) error {
|
||||
if err := r.db.WithContext(ctx).Save(model).Error; err != nil {
|
||||
return fmt.Errorf("error updating record: %w", err)
|
||||
@@ -82,7 +73,6 @@ func (r *BaseRepository[T, F]) Update(ctx context.Context, model *T) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a record by ID
|
||||
func (r *BaseRepository[T, F]) Delete(ctx context.Context, id interface{}) error {
|
||||
if err := r.db.WithContext(ctx).Delete(new(T), id).Error; err != nil {
|
||||
return fmt.Errorf("error deleting record: %w", err)
|
||||
@@ -90,7 +80,6 @@ func (r *BaseRepository[T, F]) Delete(ctx context.Context, id interface{}) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count returns the total number of records matching the filter
|
||||
func (r *BaseRepository[T, F]) Count(ctx context.Context, filter *F) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&r.modelType)
|
||||
@@ -106,8 +95,6 @@ func (r *BaseRepository[T, F]) Count(ctx context.Context, filter *F) (int64, err
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// Interfaces for filter capabilities
|
||||
|
||||
type Filterable interface {
|
||||
ApplyFilter(*gorm.DB) *gorm.DB
|
||||
}
|
||||
|
||||
@@ -17,13 +17,11 @@ func NewConfigRepository(db *gorm.DB) *ConfigRepository {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig updates or creates a Config record
|
||||
func (r *ConfigRepository) UpdateConfig(ctx context.Context, config *model.Config) *model.Config {
|
||||
if err := r.Update(ctx, config); err != nil {
|
||||
// If update fails, try to insert
|
||||
if err := r.Insert(ctx, config); err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return config
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,20 +8,16 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// MembershipRepository handles database operations for users, roles, and permissions.
|
||||
type MembershipRepository struct {
|
||||
*BaseRepository[model.User, model.MembershipFilter]
|
||||
}
|
||||
|
||||
// NewMembershipRepository creates a new MembershipRepository.
|
||||
func NewMembershipRepository(db *gorm.DB) *MembershipRepository {
|
||||
return &MembershipRepository{
|
||||
BaseRepository: NewBaseRepository[model.User, model.MembershipFilter](db, model.User{}),
|
||||
}
|
||||
}
|
||||
|
||||
// FindUserByUsername finds a user by their username.
|
||||
// It preloads the user's role and the role's permissions.
|
||||
func (r *MembershipRepository) FindUserByUsername(ctx context.Context, username string) (*model.User, error) {
|
||||
var user model.User
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -32,7 +28,6 @@ func (r *MembershipRepository) FindUserByUsername(ctx context.Context, username
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// FindUserByIDWithPermissions finds a user by their ID and preloads Role and Permissions.
|
||||
func (r *MembershipRepository) FindUserByIDWithPermissions(ctx context.Context, userID string) (*model.User, error) {
|
||||
var user model.User
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -43,13 +38,11 @@ func (r *MembershipRepository) FindUserByIDWithPermissions(ctx context.Context,
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user.
|
||||
func (r *MembershipRepository) CreateUser(ctx context.Context, user *model.User) error {
|
||||
db := r.db.WithContext(ctx)
|
||||
return db.Create(user).Error
|
||||
}
|
||||
|
||||
// FindRoleByName finds a role by its name.
|
||||
func (r *MembershipRepository) FindRoleByName(ctx context.Context, name string) (*model.Role, error) {
|
||||
var role model.Role
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -60,13 +53,11 @@ func (r *MembershipRepository) FindRoleByName(ctx context.Context, name string)
|
||||
return &role, nil
|
||||
}
|
||||
|
||||
// CreateRole creates a new role.
|
||||
func (r *MembershipRepository) CreateRole(ctx context.Context, role *model.Role) error {
|
||||
db := r.db.WithContext(ctx)
|
||||
return db.Create(role).Error
|
||||
}
|
||||
|
||||
// FindPermissionByName finds a permission by its name.
|
||||
func (r *MembershipRepository) FindPermissionByName(ctx context.Context, name string) (*model.Permission, error) {
|
||||
var permission model.Permission
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -77,19 +68,16 @@ func (r *MembershipRepository) FindPermissionByName(ctx context.Context, name st
|
||||
return &permission, nil
|
||||
}
|
||||
|
||||
// CreatePermission creates a new permission.
|
||||
func (r *MembershipRepository) CreatePermission(ctx context.Context, permission *model.Permission) error {
|
||||
db := r.db.WithContext(ctx)
|
||||
return db.Create(permission).Error
|
||||
}
|
||||
|
||||
// AssignPermissionsToRole assigns a set of permissions to a role.
|
||||
func (r *MembershipRepository) AssignPermissionsToRole(ctx context.Context, role *model.Role, permissions []model.Permission) error {
|
||||
db := r.db.WithContext(ctx)
|
||||
return db.Model(role).Association("Permissions").Replace(permissions)
|
||||
}
|
||||
|
||||
// GetUserPermissions retrieves all permissions for a given user ID.
|
||||
func (r *MembershipRepository) GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]string, error) {
|
||||
var user model.User
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -106,7 +94,6 @@ func (r *MembershipRepository) GetUserPermissions(ctx context.Context, userID uu
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
// ListUsers retrieves all users.
|
||||
func (r *MembershipRepository) ListUsers(ctx context.Context) ([]*model.User, error) {
|
||||
var users []*model.User
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -114,13 +101,11 @@ func (r *MembershipRepository) ListUsers(ctx context.Context) ([]*model.User, er
|
||||
return users, err
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user.
|
||||
func (r *MembershipRepository) DeleteUser(ctx context.Context, userID uuid.UUID) error {
|
||||
db := r.db.WithContext(ctx)
|
||||
return db.Delete(&model.User{}, "id = ?", userID).Error
|
||||
}
|
||||
|
||||
// FindUserByID finds a user by their ID.
|
||||
func (r *MembershipRepository) FindUserByID(ctx context.Context, userID uuid.UUID) (*model.User, error) {
|
||||
var user model.User
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -131,13 +116,11 @@ func (r *MembershipRepository) FindUserByID(ctx context.Context, userID uuid.UUI
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateUser updates a user's details in the database.
|
||||
func (r *MembershipRepository) UpdateUser(ctx context.Context, user *model.User) error {
|
||||
db := r.db.WithContext(ctx)
|
||||
return db.Save(user).Error
|
||||
}
|
||||
|
||||
// FindRoleByID finds a role by its ID.
|
||||
func (r *MembershipRepository) FindRoleByID(ctx context.Context, roleID uuid.UUID) (*model.Role, error) {
|
||||
var role model.Role
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -148,12 +131,10 @@ func (r *MembershipRepository) FindRoleByID(ctx context.Context, roleID uuid.UUI
|
||||
return &role, nil
|
||||
}
|
||||
|
||||
// ListUsersWithFilter retrieves users based on the membership filter.
|
||||
func (r *MembershipRepository) ListUsersWithFilter(ctx context.Context, filter *model.MembershipFilter) (*[]model.User, error) {
|
||||
return r.BaseRepository.GetAll(ctx, filter)
|
||||
}
|
||||
|
||||
// ListRoles retrieves all roles.
|
||||
func (r *MembershipRepository) ListRoles(ctx context.Context) ([]*model.Role, error) {
|
||||
var roles []*model.Role
|
||||
db := r.db.WithContext(ctx)
|
||||
|
||||
@@ -10,11 +10,7 @@ import (
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
// InitializeRepositories
|
||||
// Initializes Dependency Injection modules for repositories
|
||||
//
|
||||
// Args:
|
||||
// *dig.Container: Dig Container
|
||||
// *dig.Container: Dig Container
|
||||
func InitializeRepositories(c *dig.Container) {
|
||||
c.Provide(NewServiceControlRepository)
|
||||
c.Provide(NewStateHistoryRepository)
|
||||
@@ -24,16 +20,14 @@ func InitializeRepositories(c *dig.Container) {
|
||||
c.Provide(NewSteamCredentialsRepository)
|
||||
c.Provide(NewMembershipRepository)
|
||||
|
||||
// Provide the Steam2FAManager as a singleton
|
||||
if err := c.Provide(func() *model.Steam2FAManager {
|
||||
manager := model.NewSteam2FAManager()
|
||||
|
||||
// Use graceful shutdown manager for cleanup goroutine
|
||||
|
||||
shutdownManager := graceful.GetManager()
|
||||
shutdownManager.RunGoroutine(func(ctx context.Context) {
|
||||
ticker := time.NewTicker(15 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -43,7 +37,7 @@ func InitializeRepositories(c *dig.Container) {
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
return manager
|
||||
}); err != nil {
|
||||
logging.Panic("unable to initialize steam 2fa manager")
|
||||
|
||||
@@ -20,10 +20,6 @@ func NewServerRepository(db *gorm.DB) *ServerRepository {
|
||||
return repo
|
||||
}
|
||||
|
||||
// GetFirstByServiceName
|
||||
// Gets first row from Server table.
|
||||
//
|
||||
// Args:
|
||||
// context.Context: Application context
|
||||
// Returns:
|
||||
// model.ServerModel: Server object from database.
|
||||
|
||||
@@ -18,17 +18,14 @@ func NewStateHistoryRepository(db *gorm.DB) *StateHistoryRepository {
|
||||
}
|
||||
}
|
||||
|
||||
// GetAll retrieves all state history records with the given filter
|
||||
func (r *StateHistoryRepository) GetAll(ctx context.Context, filter *model.StateHistoryFilter) (*[]model.StateHistory, error) {
|
||||
return r.BaseRepository.GetAll(ctx, filter)
|
||||
}
|
||||
|
||||
// Insert creates a new state history record
|
||||
func (r *StateHistoryRepository) Insert(ctx context.Context, model *model.StateHistory) error {
|
||||
return r.BaseRepository.Insert(ctx, model)
|
||||
}
|
||||
|
||||
// GetLastSessionID gets the last session ID for a server
|
||||
func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID uuid.UUID) (uuid.UUID, error) {
|
||||
var lastSession model.StateHistory
|
||||
result := r.BaseRepository.db.WithContext(ctx).
|
||||
@@ -38,7 +35,7 @@ func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID
|
||||
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return uuid.Nil, nil // Return nil UUID if no sessions found
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
return uuid.Nil, result.Error
|
||||
}
|
||||
@@ -46,10 +43,8 @@ func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID
|
||||
return lastSession.SessionID, nil
|
||||
}
|
||||
|
||||
// GetSummaryStats calculates peak players, total sessions, and average players.
|
||||
func (r *StateHistoryRepository) GetSummaryStats(ctx context.Context, filter *model.StateHistoryFilter) (model.StateHistoryStats, error) {
|
||||
var stats model.StateHistoryStats
|
||||
// Parse ServerID to UUID for query
|
||||
serverUUID, err := uuid.Parse(filter.ServerID)
|
||||
if err != nil {
|
||||
return model.StateHistoryStats{}, err
|
||||
@@ -73,12 +68,10 @@ func (r *StateHistoryRepository) GetSummaryStats(ctx context.Context, filter *mo
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetTotalPlaytime calculates the total playtime in minutes.
|
||||
func (r *StateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *model.StateHistoryFilter) (int, error) {
|
||||
var totalPlaytime struct {
|
||||
TotalMinutes float64
|
||||
}
|
||||
// Parse ServerID to UUID for query
|
||||
serverUUID, err := uuid.Parse(filter.ServerID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -100,10 +93,8 @@ func (r *StateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *m
|
||||
return int(totalPlaytime.TotalMinutes), nil
|
||||
}
|
||||
|
||||
// GetPlayerCountOverTime gets downsampled player count data.
|
||||
func (r *StateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, filter *model.StateHistoryFilter) ([]model.PlayerCountPoint, error) {
|
||||
var points []model.PlayerCountPoint
|
||||
// Parse ServerID to UUID for query
|
||||
serverUUID, err := uuid.Parse(filter.ServerID)
|
||||
if err != nil {
|
||||
return points, err
|
||||
@@ -122,10 +113,8 @@ func (r *StateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, fil
|
||||
return points, err
|
||||
}
|
||||
|
||||
// GetSessionTypes counts sessions by type.
|
||||
func (r *StateHistoryRepository) GetSessionTypes(ctx context.Context, filter *model.StateHistoryFilter) ([]model.SessionCount, error) {
|
||||
var sessionTypes []model.SessionCount
|
||||
// Parse ServerID to UUID for query
|
||||
serverUUID, err := uuid.Parse(filter.ServerID)
|
||||
if err != nil {
|
||||
return sessionTypes, err
|
||||
@@ -145,10 +134,8 @@ func (r *StateHistoryRepository) GetSessionTypes(ctx context.Context, filter *mo
|
||||
return sessionTypes, err
|
||||
}
|
||||
|
||||
// GetDailyActivity counts sessions per day.
|
||||
func (r *StateHistoryRepository) GetDailyActivity(ctx context.Context, filter *model.StateHistoryFilter) ([]model.DailyActivity, error) {
|
||||
var dailyActivity []model.DailyActivity
|
||||
// Parse ServerID to UUID for query
|
||||
serverUUID, err := uuid.Parse(filter.ServerID)
|
||||
if err != nil {
|
||||
return dailyActivity, err
|
||||
@@ -167,10 +154,8 @@ func (r *StateHistoryRepository) GetDailyActivity(ctx context.Context, filter *m
|
||||
return dailyActivity, err
|
||||
}
|
||||
|
||||
// GetRecentSessions retrieves the 10 most recent sessions.
|
||||
func (r *StateHistoryRepository) GetRecentSessions(ctx context.Context, filter *model.StateHistoryFilter) ([]model.RecentSession, error) {
|
||||
var recentSessions []model.RecentSession
|
||||
// Parse ServerID to UUID for query
|
||||
serverUUID, err := uuid.Parse(filter.ServerID)
|
||||
if err != nil {
|
||||
return recentSessions, err
|
||||
|
||||
@@ -77,8 +77,8 @@ func NewConfigService(repository *repository.ConfigRepository, serverRepository
|
||||
repository: repository,
|
||||
serverRepository: serverRepository,
|
||||
configCache: model.NewServerConfigCache(model.CacheConfig{
|
||||
ExpirationTime: 5 * time.Minute, // Cache configs for 5 minutes
|
||||
ThrottleTime: 1 * time.Second, // Prevent rapid re-reads
|
||||
ExpirationTime: 5 * time.Minute,
|
||||
ThrottleTime: 1 * time.Second,
|
||||
DefaultStatus: model.StatusUnknown,
|
||||
}),
|
||||
}
|
||||
@@ -88,10 +88,6 @@ func (as *ConfigService) SetServerService(serverService *ServerService) {
|
||||
as.serverService = serverService
|
||||
}
|
||||
|
||||
// UpdateConfig
|
||||
// Updates physical config file and caches it in database.
|
||||
//
|
||||
// Args:
|
||||
// context.Context: Application context
|
||||
// Returns:
|
||||
// string: Application version
|
||||
@@ -103,7 +99,6 @@ func (as *ConfigService) UpdateConfig(ctx *fiber.Ctx, body *map[string]interface
|
||||
return as.updateConfigInternal(ctx.UserContext(), serverID, configFile, body, override)
|
||||
}
|
||||
|
||||
// updateConfigInternal handles the actual config update logic without Fiber dependencies
|
||||
func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID string, configFile string, body *map[string]interface{}, override bool) (*model.Config, error) {
|
||||
serverUUID, err := uuid.Parse(serverID)
|
||||
if err != nil {
|
||||
@@ -117,17 +112,14 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri
|
||||
return nil, fmt.Errorf("server not found")
|
||||
}
|
||||
|
||||
// Read existing config
|
||||
configPath := filepath.Join(server.GetConfigPath(), configFile)
|
||||
oldData, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// Create directory if it doesn't exist
|
||||
dir := filepath.Dir(configPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Create empty JSON file
|
||||
if err := os.WriteFile(configPath, []byte("{}"), 0644); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -142,7 +134,6 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write new config
|
||||
newData, err := json.Marshal(&body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -168,12 +159,9 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Invalidate all configs for this server since configs can be interdependent
|
||||
as.configCache.InvalidateServerCache(serverID)
|
||||
|
||||
as.serverService.StartAccServerRuntime(server)
|
||||
|
||||
// Log change
|
||||
return as.repository.UpdateConfig(ctx, &model.Config{
|
||||
ServerID: serverUUID,
|
||||
ConfigFile: configFile,
|
||||
@@ -183,10 +171,6 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri
|
||||
}), nil
|
||||
}
|
||||
|
||||
// GetConfig
|
||||
// Gets physical config file and caches it in database.
|
||||
//
|
||||
// Args:
|
||||
// context.Context: Application context
|
||||
// Returns:
|
||||
// string: Application version
|
||||
@@ -202,7 +186,6 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) {
|
||||
return nil, fiber.NewError(404, "Server not found")
|
||||
}
|
||||
|
||||
// Try to get from cache based on config file type
|
||||
switch configFile {
|
||||
case ConfigurationJson:
|
||||
if cached, ok := as.configCache.GetConfiguration(serverIDStr); ok {
|
||||
@@ -233,7 +216,6 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) {
|
||||
|
||||
logging.Debug("Cache miss for server ID: %s, file: %s - loading from disk", serverIDStr, configFile)
|
||||
|
||||
// Not in cache, load from disk
|
||||
configPath := filepath.Join(server.GetConfigPath(), configFile)
|
||||
decoder := DecodeFileName(configFile)
|
||||
if decoder == nil {
|
||||
@@ -244,7 +226,6 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) {
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
logging.Debug("Config file not found, creating default for server ID: %s, file: %s", serverIDStr, configFile)
|
||||
// Return empty config if file doesn't exist
|
||||
switch configFile {
|
||||
case ConfigurationJson:
|
||||
return &model.Configuration{}, nil
|
||||
@@ -261,7 +242,6 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cache the loaded config
|
||||
switch configFile {
|
||||
case ConfigurationJson:
|
||||
as.configCache.UpdateConfiguration(serverIDStr, *config.(*model.Configuration))
|
||||
@@ -279,8 +259,6 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// GetConfigs
|
||||
// Gets all configurations for a server, using cache when possible.
|
||||
func (as *ConfigService) GetConfigs(ctx *fiber.Ctx) (*model.Configurations, error) {
|
||||
serverID := ctx.Params("id")
|
||||
|
||||
@@ -298,7 +276,6 @@ func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configuration
|
||||
logging.Info("Loading configs for server ID: %s at path: %s", serverIDStr, server.GetConfigPath())
|
||||
configs := &model.Configurations{}
|
||||
|
||||
// Load configuration
|
||||
if cached, ok := as.configCache.GetConfiguration(serverIDStr); ok {
|
||||
logging.Debug("Using cached configuration for server %s", serverIDStr)
|
||||
configs.Configuration = *cached
|
||||
@@ -313,7 +290,6 @@ func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configuration
|
||||
as.configCache.UpdateConfiguration(serverIDStr, config)
|
||||
}
|
||||
|
||||
// Load assist rules
|
||||
if cached, ok := as.configCache.GetAssistRules(serverIDStr); ok {
|
||||
logging.Debug("Using cached assist rules for server %s", serverIDStr)
|
||||
configs.AssistRules = *cached
|
||||
@@ -328,7 +304,6 @@ func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configuration
|
||||
as.configCache.UpdateAssistRules(serverIDStr, rules)
|
||||
}
|
||||
|
||||
// Load event config
|
||||
if cached, ok := as.configCache.GetEvent(serverIDStr); ok {
|
||||
logging.Debug("Using cached event config for server %s", serverIDStr)
|
||||
configs.Event = *cached
|
||||
@@ -344,7 +319,6 @@ func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configuration
|
||||
as.configCache.UpdateEvent(serverIDStr, event)
|
||||
}
|
||||
|
||||
// Load event rules
|
||||
if cached, ok := as.configCache.GetEventRules(serverIDStr); ok {
|
||||
logging.Debug("Using cached event rules for server %s", serverIDStr)
|
||||
configs.EventRules = *cached
|
||||
@@ -359,7 +333,6 @@ func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configuration
|
||||
as.configCache.UpdateEventRules(serverIDStr, rules)
|
||||
}
|
||||
|
||||
// Load settings
|
||||
if cached, ok := as.configCache.GetSettings(serverIDStr); ok {
|
||||
logging.Debug("Using cached settings for server %s", serverIDStr)
|
||||
configs.Settings = *cached
|
||||
@@ -475,9 +448,7 @@ func (as *ConfigService) GetConfiguration(server *model.Server) (*model.Configur
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// SaveConfiguration saves the configuration for a server
|
||||
func (as *ConfigService) SaveConfiguration(server *model.Server, config *model.Configuration) error {
|
||||
// Convert config to map for UpdateConfig
|
||||
configMap := make(map[string]interface{})
|
||||
configBytes, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
@@ -487,7 +458,6 @@ func (as *ConfigService) SaveConfiguration(server *model.Server, config *model.C
|
||||
return fmt.Errorf("failed to unmarshal configuration: %v", err)
|
||||
}
|
||||
|
||||
// Update the configuration using the internal method
|
||||
_, err = as.updateConfigInternal(context.Background(), server.ID.String(), ConfigurationJson, &configMap, true)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -96,11 +96,9 @@ func (s *FirewallService) DeleteServerRules(serverName string, tcpPorts, udpPort
|
||||
}
|
||||
|
||||
func (s *FirewallService) UpdateServerRules(serverName string, tcpPorts, udpPorts []int) error {
|
||||
// First delete existing rules
|
||||
if err := s.DeleteServerRules(serverName, tcpPorts, udpPorts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then create new rules
|
||||
return s.CreateServerRules(serverName, tcpPorts, udpPorts)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,13 +12,11 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// CacheInvalidator interface for cache invalidation
|
||||
type CacheInvalidator interface {
|
||||
InvalidateUserPermissions(userID string)
|
||||
InvalidateAllUserPermissions()
|
||||
}
|
||||
|
||||
// MembershipService provides business logic for membership-related operations.
|
||||
type MembershipService struct {
|
||||
repo *repository.MembershipRepository
|
||||
cacheInvalidator CacheInvalidator
|
||||
@@ -26,29 +24,25 @@ type MembershipService struct {
|
||||
openJwtHandler *jwt.OpenJWTHandler
|
||||
}
|
||||
|
||||
// NewMembershipService creates a new MembershipService.
|
||||
func NewMembershipService(repo *repository.MembershipRepository, jwtHandler *jwt.JWTHandler, openJwtHandler *jwt.OpenJWTHandler) *MembershipService {
|
||||
return &MembershipService{
|
||||
repo: repo,
|
||||
cacheInvalidator: nil, // Will be set later via SetCacheInvalidator
|
||||
cacheInvalidator: nil,
|
||||
jwtHandler: jwtHandler,
|
||||
openJwtHandler: openJwtHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// SetCacheInvalidator sets the cache invalidator after service initialization
|
||||
func (s *MembershipService) SetCacheInvalidator(invalidator CacheInvalidator) {
|
||||
s.cacheInvalidator = invalidator
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns a JWT.
|
||||
func (s *MembershipService) HandleLogin(ctx context.Context, username, password string) (*model.User, error) {
|
||||
user, err := s.repo.FindUserByUsername(ctx, username)
|
||||
if err != nil {
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
// Use secure password verification with constant-time comparison
|
||||
if err := user.VerifyPassword(password); err != nil {
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
@@ -56,7 +50,6 @@ func (s *MembershipService) HandleLogin(ctx context.Context, username, password
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns a JWT.
|
||||
func (s *MembershipService) Login(ctx context.Context, username, password string) (string, error) {
|
||||
user, err := s.HandleLogin(ctx, username, password)
|
||||
if err != nil {
|
||||
@@ -70,7 +63,6 @@ func (s *MembershipService) GenerateOpenToken(ctx context.Context, userId string
|
||||
return s.openJwtHandler.GenerateToken(userId)
|
||||
}
|
||||
|
||||
// CreateUser creates a new user.
|
||||
func (s *MembershipService) CreateUser(ctx context.Context, username, password, roleName string) (*model.User, error) {
|
||||
|
||||
role, err := s.repo.FindRoleByName(ctx, roleName)
|
||||
@@ -94,43 +86,35 @@ func (s *MembershipService) CreateUser(ctx context.Context, username, password,
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// ListUsers retrieves all users.
|
||||
func (s *MembershipService) ListUsers(ctx context.Context) ([]*model.User, error) {
|
||||
return s.repo.ListUsers(ctx)
|
||||
}
|
||||
|
||||
// GetUser retrieves a single user by ID.
|
||||
func (s *MembershipService) GetUser(ctx context.Context, userID uuid.UUID) (*model.User, error) {
|
||||
return s.repo.FindUserByID(ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserWithPermissions retrieves a single user by ID with their role and permissions.
|
||||
func (s *MembershipService) GetUserWithPermissions(ctx context.Context, userID string) (*model.User, error) {
|
||||
return s.repo.FindUserByIDWithPermissions(ctx, userID)
|
||||
}
|
||||
|
||||
// UpdateUserRequest defines the request body for updating a user.
|
||||
type UpdateUserRequest struct {
|
||||
Username *string `json:"username"`
|
||||
Password *string `json:"password"`
|
||||
RoleID *uuid.UUID `json:"roleId"`
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user with validation to prevent Super Admin deletion.
|
||||
func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) error {
|
||||
// Get user with role information
|
||||
user, err := s.repo.FindUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return errors.New("user not found")
|
||||
}
|
||||
|
||||
// Get role to check if it's Super Admin
|
||||
role, err := s.repo.FindRoleByID(ctx, user.RoleID)
|
||||
if err != nil {
|
||||
return errors.New("user role not found")
|
||||
}
|
||||
|
||||
// Prevent deletion of Super Admin users
|
||||
if role.Name == "Super Admin" {
|
||||
return errors.New("cannot delete Super Admin user")
|
||||
}
|
||||
@@ -140,7 +124,6 @@ func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) er
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidate cache for deleted user
|
||||
if s.cacheInvalidator != nil {
|
||||
s.cacheInvalidator.InvalidateUserPermissions(userID.String())
|
||||
}
|
||||
@@ -149,7 +132,6 @@ func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) er
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateUser updates a user's details.
|
||||
func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, req UpdateUserRequest) (*model.User, error) {
|
||||
user, err := s.repo.FindUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -161,12 +143,10 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re
|
||||
}
|
||||
|
||||
if req.Password != nil && *req.Password != "" {
|
||||
// Password will be automatically hashed in BeforeUpdate hook
|
||||
user.Password = *req.Password
|
||||
}
|
||||
|
||||
if req.RoleID != nil {
|
||||
// Check if role exists
|
||||
_, err := s.repo.FindRoleByID(ctx, *req.RoleID)
|
||||
if err != nil {
|
||||
return nil, errors.New("role not found")
|
||||
@@ -178,7 +158,6 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Invalidate cache if role was changed
|
||||
if req.RoleID != nil && s.cacheInvalidator != nil {
|
||||
s.cacheInvalidator.InvalidateUserPermissions(userID.String())
|
||||
}
|
||||
@@ -187,14 +166,12 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// HasPermission checks if a user has a specific permission.
|
||||
func (s *MembershipService) HasPermission(ctx context.Context, userID string, permissionName string) (bool, error) {
|
||||
user, err := s.repo.FindUserByIDWithPermissions(ctx, userID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Super admin and Admin have all permissions
|
||||
if user.Role.Name == "Super Admin" || user.Role.Name == "Admin" {
|
||||
return true, nil
|
||||
}
|
||||
@@ -208,15 +185,13 @@ func (s *MembershipService) HasPermission(ctx context.Context, userID string, pe
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// SetupInitialData creates the initial roles and permissions.
|
||||
func (s *MembershipService) SetupInitialData(ctx context.Context) error {
|
||||
// Define all permissions
|
||||
permissions := model.AllPermissions()
|
||||
|
||||
createdPermissions := make([]model.Permission, 0)
|
||||
for _, pName := range permissions {
|
||||
perm, err := s.repo.FindPermissionByName(ctx, pName)
|
||||
if err != nil { // Assuming error means not found
|
||||
if err != nil {
|
||||
perm = &model.Permission{Name: pName}
|
||||
if err := s.repo.CreatePermission(ctx, perm); err != nil {
|
||||
return err
|
||||
@@ -225,7 +200,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
|
||||
createdPermissions = append(createdPermissions, *perm)
|
||||
}
|
||||
|
||||
// Create Super Admin role with all permissions
|
||||
superAdminRole, err := s.repo.FindRoleByName(ctx, "Super Admin")
|
||||
if err != nil {
|
||||
superAdminRole = &model.Role{Name: "Super Admin"}
|
||||
@@ -237,7 +211,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create Admin role with same permissions as Super Admin
|
||||
adminRole, err := s.repo.FindRoleByName(ctx, "Admin")
|
||||
if err != nil {
|
||||
adminRole = &model.Role{Name: "Admin"}
|
||||
@@ -249,7 +222,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create Manager role with limited permissions (excluding membership, role, user, server create/delete)
|
||||
managerRole, err := s.repo.FindRoleByName(ctx, "Manager")
|
||||
if err != nil {
|
||||
managerRole = &model.Role{Name: "Manager"}
|
||||
@@ -258,7 +230,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Define manager permissions (limited set)
|
||||
managerPermissionNames := []string{
|
||||
model.ServerView,
|
||||
model.ServerUpdate,
|
||||
@@ -282,16 +253,14 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidate all caches after role setup changes
|
||||
if s.cacheInvalidator != nil {
|
||||
s.cacheInvalidator.InvalidateAllUserPermissions()
|
||||
}
|
||||
|
||||
// Create a default admin user if one doesn't exist
|
||||
_, err = s.repo.FindUserByUsername(ctx, "admin")
|
||||
if err != nil {
|
||||
logging.Debug("Creating default admin user")
|
||||
_, err = s.CreateUser(ctx, "admin", os.Getenv("PASSWORD"), "Super Admin") // Default password, should be changed
|
||||
_, err = s.CreateUser(ctx, "admin", os.Getenv("PASSWORD"), "Super Admin")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -300,7 +269,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllRoles retrieves all roles for dropdown selection.
|
||||
func (s *MembershipService) GetAllRoles(ctx context.Context) ([]*model.Role, error) {
|
||||
return s.repo.ListRoles(ctx)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
|
||||
const (
|
||||
DefaultStartPort = 9600
|
||||
RequiredPortCount = 1 // Update this if ACC needs more ports
|
||||
RequiredPortCount = 1
|
||||
)
|
||||
|
||||
type ServerService struct {
|
||||
@@ -45,18 +45,15 @@ type pendingState struct {
|
||||
}
|
||||
|
||||
func (s *ServerService) ensureLogTailing(server *model.Server, instance *tracking.AccServerInstance) {
|
||||
// Check if we already have a tailer
|
||||
if _, exists := s.logTailers.Load(server.ID); exists {
|
||||
return
|
||||
}
|
||||
|
||||
// Start tailing in a goroutine that handles file creation/deletion
|
||||
go func() {
|
||||
logPath := filepath.Join(server.GetLogPath(), "server.log")
|
||||
tailer := tracking.NewLogTailer(logPath, instance.HandleLogLine)
|
||||
s.logTailers.Store(server.ID, tailer)
|
||||
|
||||
// Start tailing and automatically handle file changes
|
||||
tailer.Start()
|
||||
}()
|
||||
}
|
||||
@@ -82,7 +79,6 @@ func NewServerService(
|
||||
webSocketService: webSocketService,
|
||||
}
|
||||
|
||||
// Initialize server instances
|
||||
servers, err := repository.GetAll(context.Background(), &model.ServerFilter{})
|
||||
if err != nil {
|
||||
logging.Error("Failed to get servers: %v", err)
|
||||
@@ -90,7 +86,6 @@ func NewServerService(
|
||||
}
|
||||
|
||||
for i := range *servers {
|
||||
// Initialize instance regardless of status
|
||||
logging.Info("Starting server runtime for server ID: %d", (*servers)[i].ID)
|
||||
service.StartAccServerRuntime(&(*servers)[i])
|
||||
}
|
||||
@@ -99,7 +94,7 @@ func NewServerService(
|
||||
}
|
||||
|
||||
func (s *ServerService) shouldInsertStateHistory(serverID uuid.UUID) bool {
|
||||
insertInterval := 5 * time.Minute // Configure this as needed
|
||||
insertInterval := 5 * time.Minute
|
||||
|
||||
lastInsertInterface, exists := s.lastInsertTimes.Load(serverID)
|
||||
if !exists {
|
||||
@@ -122,16 +117,15 @@ func (s *ServerService) getNextSessionID(serverID uuid.UUID) uuid.UUID {
|
||||
lastID, err := s.stateHistoryRepo.GetLastSessionID(context.Background(), serverID)
|
||||
if err != nil {
|
||||
logging.Error("Failed to get last session ID for server %s: %v", serverID, err)
|
||||
return uuid.New() // Return new UUID as fallback
|
||||
return uuid.New()
|
||||
}
|
||||
if lastID == uuid.Nil {
|
||||
return uuid.New() // Return new UUID if no previous session
|
||||
return uuid.New()
|
||||
}
|
||||
return uuid.New() // Always generate new UUID for each session
|
||||
return uuid.New()
|
||||
}
|
||||
|
||||
func (s *ServerService) insertStateHistory(serverID uuid.UUID, state *model.ServerState) {
|
||||
// Get or create session ID when session changes
|
||||
currentSessionInterface, exists := s.instances.Load(serverID)
|
||||
var sessionID uuid.UUID
|
||||
if !exists {
|
||||
@@ -163,7 +157,6 @@ func (s *ServerService) insertStateHistory(serverID uuid.UUID, state *model.Serv
|
||||
}
|
||||
|
||||
func (s *ServerService) updateSessionDuration(server *model.Server, sessionType model.TrackSession) {
|
||||
// Get configs using helper methods
|
||||
event, err := s.configService.GetEventConfig(server)
|
||||
if err != nil {
|
||||
event = &model.EventConfig{}
|
||||
@@ -181,9 +174,7 @@ func (s *ServerService) updateSessionDuration(server *model.Server, sessionType
|
||||
serverInstance.State.Track = event.Track
|
||||
serverInstance.State.MaxConnections = configuration.MaxConnections.ToInt()
|
||||
|
||||
// Check if session type has changed
|
||||
if serverInstance.State.Session != sessionType {
|
||||
// Get new session ID for the new session
|
||||
sessionID := s.getNextSessionID(server.ID)
|
||||
s.sessionIDs.Store(server.ID, sessionID)
|
||||
}
|
||||
@@ -204,33 +195,27 @@ func (s *ServerService) updateSessionDuration(server *model.Server, sessionType
|
||||
}
|
||||
|
||||
func (s *ServerService) GenerateServerPath(server *model.Server) {
|
||||
// Get the base steamcmd path from environment variable
|
||||
steamCMDPath := env.GetSteamCMDDirPath()
|
||||
server.FromSteamCMD = true
|
||||
server.Path = server.GenerateServerPath(steamCMDPath)
|
||||
server.FromSteamCMD = true
|
||||
}
|
||||
|
||||
func (s *ServerService) handleStateChange(server *model.Server, state *model.ServerState) {
|
||||
// Update session duration when session changes
|
||||
s.updateSessionDuration(server, state.Session)
|
||||
|
||||
// Invalidate status cache when server state changes
|
||||
s.apiService.statusCache.InvalidateStatus(server.ServiceName)
|
||||
|
||||
// Cancel existing timer if any
|
||||
if debouncer, exists := s.debouncers.Load(server.ID); exists {
|
||||
pending := debouncer.(*pendingState)
|
||||
pending.timer.Stop()
|
||||
}
|
||||
|
||||
// Create new timer
|
||||
timer := time.NewTimer(5 * time.Minute)
|
||||
s.debouncers.Store(server.ID, &pendingState{
|
||||
timer: timer,
|
||||
state: state,
|
||||
})
|
||||
|
||||
// Start goroutine to handle the delayed insert
|
||||
go func() {
|
||||
<-timer.C
|
||||
if debouncer, exists := s.debouncers.Load(server.ID); exists {
|
||||
@@ -240,14 +225,12 @@ func (s *ServerService) handleStateChange(server *model.Server, state *model.Ser
|
||||
}
|
||||
}()
|
||||
|
||||
// If enough time has passed since last insert, insert immediately
|
||||
if s.shouldInsertStateHistory(server.ID) {
|
||||
s.insertStateHistory(server.ID, state)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerService) StartAccServerRuntime(server *model.Server) {
|
||||
// Get or create instance
|
||||
instanceInterface, exists := s.instances.Load(server.ID)
|
||||
var instance *tracking.AccServerInstance
|
||||
if !exists {
|
||||
@@ -259,20 +242,14 @@ func (s *ServerService) StartAccServerRuntime(server *model.Server) {
|
||||
instance = instanceInterface.(*tracking.AccServerInstance)
|
||||
}
|
||||
|
||||
// Invalidate config cache for this server before loading new configs
|
||||
serverIDStr := server.ID.String()
|
||||
s.configService.configCache.InvalidateServerCache(serverIDStr)
|
||||
|
||||
s.updateSessionDuration(server, instance.State.Session)
|
||||
|
||||
// Ensure log tailing is running (regardless of server status)
|
||||
s.ensureLogTailing(server, instance)
|
||||
}
|
||||
|
||||
// GetAll
|
||||
// Gets All rows from Server table.
|
||||
//
|
||||
// Args:
|
||||
// context.Context: Application context
|
||||
// Returns:
|
||||
// string: Application version
|
||||
@@ -304,10 +281,6 @@ func (s *ServerService) GetAll(ctx *fiber.Ctx, filter *model.ServerFilter) (*[]m
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// GetById
|
||||
// Gets rows by ID from Server table.
|
||||
//
|
||||
// Args:
|
||||
// context.Context: Application context
|
||||
// Returns:
|
||||
// string: Application version
|
||||
@@ -334,22 +307,16 @@ func (as *ServerService) GetById(ctx *fiber.Ctx, serverID uuid.UUID) (*model.Ser
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// CreateServerAsync starts server creation asynchronously and returns immediately
|
||||
func (s *ServerService) CreateServerAsync(ctx *fiber.Ctx, server *model.Server) error {
|
||||
// Perform basic validation first
|
||||
if err := server.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Generate server path
|
||||
s.GenerateServerPath(server)
|
||||
|
||||
// Create a background context that won't be cancelled when the HTTP request ends
|
||||
bgCtx := context.Background()
|
||||
|
||||
// Start the actual creation process in a goroutine
|
||||
go func() {
|
||||
// Create server in background without using fiber.Ctx
|
||||
if err := s.createServerBackground(bgCtx, server); err != nil {
|
||||
logging.Error("Async server creation failed for server %s: %v", server.ID, err)
|
||||
s.webSocketService.BroadcastError(server.ID, "Server creation failed", err.Error())
|
||||
@@ -361,11 +328,9 @@ func (s *ServerService) CreateServerAsync(ctx *fiber.Ctx, server *model.Server)
|
||||
}
|
||||
|
||||
func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error {
|
||||
// Broadcast step: validation
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusInProgress,
|
||||
model.GetStepDescription(model.StepValidation), "")
|
||||
|
||||
// Validate basic server configuration
|
||||
if err := server.Validate(); err != nil {
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusFailed,
|
||||
"", fmt.Sprintf("Validation failed: %v", err))
|
||||
@@ -375,19 +340,15 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusCompleted,
|
||||
"Server configuration validated successfully", "")
|
||||
|
||||
// Broadcast step: directory creation
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusInProgress,
|
||||
model.GetStepDescription(model.StepDirectoryCreation), "")
|
||||
|
||||
// Directory creation is handled within InstallServer, so we mark it as completed
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusCompleted,
|
||||
"Server directories prepared", "")
|
||||
|
||||
// Broadcast step: Steam download
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusInProgress,
|
||||
model.GetStepDescription(model.StepSteamDownload), "")
|
||||
|
||||
// Install server using SteamCMD with streaming support
|
||||
if err := s.steamService.InstallServerWithWebSocket(ctx.UserContext(), server.Path, &server.ID, s.webSocketService); err != nil {
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusFailed,
|
||||
"", fmt.Sprintf("Steam installation failed: %v", err))
|
||||
@@ -397,11 +358,9 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusCompleted,
|
||||
"Server files downloaded successfully", "")
|
||||
|
||||
// Broadcast step: config generation
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusInProgress,
|
||||
model.GetStepDescription(model.StepConfigGeneration), "")
|
||||
|
||||
// Find available ports for server
|
||||
ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount)
|
||||
if err != nil {
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed,
|
||||
@@ -409,10 +368,8 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
|
||||
return fmt.Errorf("failed to find available ports: %v", err)
|
||||
}
|
||||
|
||||
// Use the first port for both TCP and UDP
|
||||
serverPort := ports[0]
|
||||
|
||||
// Update server configuration with the allocated port
|
||||
if err := s.updateServerPort(server, serverPort); err != nil {
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed,
|
||||
"", fmt.Sprintf("Failed to update server configuration: %v", err))
|
||||
@@ -422,17 +379,14 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusCompleted,
|
||||
fmt.Sprintf("Server configuration generated (Port: %d)", serverPort), "")
|
||||
|
||||
// Broadcast step: service creation
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusInProgress,
|
||||
model.GetStepDescription(model.StepServiceCreation), "")
|
||||
|
||||
// Create Windows service with correct paths
|
||||
execPath := filepath.Join(server.GetServerPath(), "accServer.exe")
|
||||
serverWorkingDir := filepath.Join(server.GetServerPath(), "server")
|
||||
if err := s.windowsService.CreateService(ctx.UserContext(), server.ServiceName, execPath, serverWorkingDir, nil); err != nil {
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusFailed,
|
||||
"", fmt.Sprintf("Failed to create Windows service: %v", err))
|
||||
// Cleanup on failure
|
||||
s.steamService.UninstallServer(server.Path)
|
||||
return fmt.Errorf("failed to create Windows service: %v", err)
|
||||
}
|
||||
@@ -440,7 +394,6 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusCompleted,
|
||||
fmt.Sprintf("Windows service '%s' created successfully", server.ServiceName), "")
|
||||
|
||||
// Broadcast step: firewall rules
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusInProgress,
|
||||
model.GetStepDescription(model.StepFirewallRules), "")
|
||||
|
||||
@@ -450,7 +403,6 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
|
||||
if err := s.firewallService.CreateServerRules(server.ServiceName, tcpPorts, udpPorts); err != nil {
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusFailed,
|
||||
"", fmt.Sprintf("Failed to create firewall rules: %v", err))
|
||||
// Cleanup on failure
|
||||
s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName)
|
||||
s.steamService.UninstallServer(server.Path)
|
||||
return fmt.Errorf("failed to create firewall rules: %v", err)
|
||||
@@ -459,15 +411,12 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusCompleted,
|
||||
fmt.Sprintf("Firewall rules created for port %d", serverPort), "")
|
||||
|
||||
// Broadcast step: database save
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusInProgress,
|
||||
model.GetStepDescription(model.StepDatabaseSave), "")
|
||||
|
||||
// Insert server into database
|
||||
if err := s.repository.Insert(ctx.UserContext(), server); err != nil {
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusFailed,
|
||||
"", fmt.Sprintf("Failed to save server to database: %v", err))
|
||||
// Cleanup on failure
|
||||
s.firewallService.DeleteServerRules(server.ServiceName, tcpPorts, udpPorts)
|
||||
s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName)
|
||||
s.steamService.UninstallServer(server.Path)
|
||||
@@ -477,10 +426,8 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusCompleted,
|
||||
"Server saved to database successfully", "")
|
||||
|
||||
// Initialize server runtime
|
||||
s.StartAccServerRuntime(server)
|
||||
|
||||
// Broadcast completion
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepCompleted, model.StatusCompleted,
|
||||
model.GetStepDescription(model.StepCompleted), "")
|
||||
|
||||
@@ -490,13 +437,10 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// createServerBackground performs server creation in background without fiber.Ctx
|
||||
func (s *ServerService) createServerBackground(ctx context.Context, server *model.Server) error {
|
||||
// Broadcast step: validation
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusInProgress,
|
||||
model.GetStepDescription(model.StepValidation), "")
|
||||
|
||||
// Validate basic server configuration (already done in async method, but double-check)
|
||||
if err := server.Validate(); err != nil {
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusFailed,
|
||||
"", fmt.Sprintf("Validation failed: %v", err))
|
||||
@@ -506,20 +450,16 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusCompleted,
|
||||
"Server configuration validated successfully", "")
|
||||
|
||||
// Broadcast step: directory creation
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusInProgress,
|
||||
model.GetStepDescription(model.StepDirectoryCreation), "")
|
||||
|
||||
// Directory creation is handled within InstallServer, so we mark it as completed
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusCompleted,
|
||||
"Server directories prepared", "")
|
||||
|
||||
// Broadcast step: Steam download
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusInProgress,
|
||||
model.GetStepDescription(model.StepSteamDownload), "")
|
||||
|
||||
// Install server using SteamCMD with streaming support
|
||||
if err := s.steamService.InstallServerWithWebSocket(ctx, server.GetServerPath(), &server.ID, s.webSocketService); err != nil {
|
||||
if err := s.steamService.InstallServerWithWebSocket(ctx, server.Path, &server.ID, s.webSocketService); err != nil {
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusFailed,
|
||||
"", fmt.Sprintf("Steam installation failed: %v", err))
|
||||
return fmt.Errorf("failed to install server: %v", err)
|
||||
@@ -528,11 +468,9 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusCompleted,
|
||||
"Server files downloaded successfully", "")
|
||||
|
||||
// Broadcast step: config generation
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusInProgress,
|
||||
model.GetStepDescription(model.StepConfigGeneration), "")
|
||||
|
||||
// Find available ports for server
|
||||
ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount)
|
||||
if err != nil {
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed,
|
||||
@@ -540,10 +478,8 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
|
||||
return fmt.Errorf("failed to find available ports: %v", err)
|
||||
}
|
||||
|
||||
// Use the first port for both TCP and UDP
|
||||
serverPort := ports[0]
|
||||
|
||||
// Update server configuration with the allocated port
|
||||
if err := s.updateServerPort(server, serverPort); err != nil {
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed,
|
||||
"", fmt.Sprintf("Failed to update server configuration: %v", err))
|
||||
@@ -553,17 +489,14 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusCompleted,
|
||||
fmt.Sprintf("Server configuration generated (Port: %d)", serverPort), "")
|
||||
|
||||
// Broadcast step: service creation
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusInProgress,
|
||||
model.GetStepDescription(model.StepServiceCreation), "")
|
||||
|
||||
// Create Windows service with correct paths
|
||||
execPath := filepath.Join(server.GetServerPath(), "accServer.exe")
|
||||
serverWorkingDir := filepath.Join(server.GetServerPath(), "server")
|
||||
if err := s.windowsService.CreateService(ctx, server.ServiceName, execPath, serverWorkingDir, nil); err != nil {
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusFailed,
|
||||
"", fmt.Sprintf("Failed to create Windows service: %v", err))
|
||||
// Cleanup on failure
|
||||
s.steamService.UninstallServer(server.Path)
|
||||
return fmt.Errorf("failed to create Windows service: %v", err)
|
||||
}
|
||||
@@ -571,7 +504,6 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusCompleted,
|
||||
fmt.Sprintf("Windows service '%s' created successfully", server.ServiceName), "")
|
||||
|
||||
// Broadcast step: firewall rules
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusInProgress,
|
||||
model.GetStepDescription(model.StepFirewallRules), "")
|
||||
|
||||
@@ -581,7 +513,6 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
|
||||
if err := s.firewallService.CreateServerRules(server.ServiceName, tcpPorts, udpPorts); err != nil {
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusFailed,
|
||||
"", fmt.Sprintf("Failed to create firewall rules: %v", err))
|
||||
// Cleanup on failure
|
||||
s.windowsService.DeleteService(ctx, server.ServiceName)
|
||||
s.steamService.UninstallServer(server.Path)
|
||||
return fmt.Errorf("failed to create firewall rules: %v", err)
|
||||
@@ -590,15 +521,12 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusCompleted,
|
||||
fmt.Sprintf("Firewall rules created for port %d", serverPort), "")
|
||||
|
||||
// Broadcast step: database save
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusInProgress,
|
||||
model.GetStepDescription(model.StepDatabaseSave), "")
|
||||
|
||||
// Insert server into database
|
||||
if err := s.repository.Insert(ctx, server); err != nil {
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusFailed,
|
||||
"", fmt.Sprintf("Failed to save server to database: %v", err))
|
||||
// Cleanup on failure
|
||||
s.firewallService.DeleteServerRules(server.ServiceName, tcpPorts, udpPorts)
|
||||
s.windowsService.DeleteService(ctx, server.ServiceName)
|
||||
s.steamService.UninstallServer(server.Path)
|
||||
@@ -608,10 +536,8 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusCompleted,
|
||||
"Server saved to database successfully", "")
|
||||
|
||||
// Initialize server runtime
|
||||
s.StartAccServerRuntime(server)
|
||||
|
||||
// Broadcast completion
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepCompleted, model.StatusCompleted,
|
||||
model.GetStepDescription(model.StepCompleted), "")
|
||||
|
||||
@@ -622,18 +548,15 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode
|
||||
}
|
||||
|
||||
func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error {
|
||||
// Get server details
|
||||
server, err := s.repository.GetByID(ctx.UserContext(), serverID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get server details: %v", err)
|
||||
}
|
||||
|
||||
// Stop and remove Windows service
|
||||
if err := s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName); err != nil {
|
||||
logging.Error("Failed to delete Windows service: %v", err)
|
||||
}
|
||||
|
||||
// Remove firewall rules
|
||||
configuration, err := s.configService.GetConfiguration(server)
|
||||
if err != nil {
|
||||
logging.Error("Failed to get configuration for server %d: %v", server.ID, err)
|
||||
@@ -644,17 +567,14 @@ func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error {
|
||||
logging.Error("Failed to delete firewall rules: %v", err)
|
||||
}
|
||||
|
||||
// Uninstall server files
|
||||
if err := s.steamService.UninstallServer(server.Path); err != nil {
|
||||
logging.Error("Failed to uninstall server: %v", err)
|
||||
}
|
||||
|
||||
// Remove from database
|
||||
if err := s.repository.Delete(ctx.UserContext(), serverID); err != nil {
|
||||
return fmt.Errorf("failed to delete server from database: %v", err)
|
||||
}
|
||||
|
||||
// Cleanup runtime resources
|
||||
if tailer, exists := s.logTailers.Load(server.ID); exists {
|
||||
tailer.(*tracking.LogTailer).Stop()
|
||||
s.logTailers.Delete(server.ID)
|
||||
@@ -664,84 +584,27 @@ func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error {
|
||||
s.debouncers.Delete(server.ID)
|
||||
s.sessionIDs.Delete(server.ID)
|
||||
|
||||
// Invalidate status cache for deleted server
|
||||
s.apiService.statusCache.InvalidateStatus(server.ServiceName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerService) UpdateServer(ctx *fiber.Ctx, server *model.Server) error {
|
||||
// Validate server configuration
|
||||
if err := server.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get existing server details
|
||||
existingServer, err := s.repository.GetByID(ctx.UserContext(), server.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing server details: %v", err)
|
||||
}
|
||||
|
||||
// Update server files if path changed
|
||||
if existingServer.Path != server.Path {
|
||||
if err := s.steamService.InstallServer(ctx.UserContext(), server.Path, &server.ID); err != nil {
|
||||
return fmt.Errorf("failed to install server to new location: %v", err)
|
||||
}
|
||||
// Clean up old installation
|
||||
if err := s.steamService.UninstallServer(existingServer.Path); err != nil {
|
||||
logging.Error("Failed to remove old server installation: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update Windows service if necessary
|
||||
if existingServer.ServiceName != server.ServiceName || existingServer.Path != server.Path {
|
||||
execPath := filepath.Join(server.GetServerPath(), "accServer.exe")
|
||||
serverWorkingDir := server.GetServerPath()
|
||||
if err := s.windowsService.UpdateService(ctx.UserContext(), server.ServiceName, execPath, serverWorkingDir, nil); err != nil {
|
||||
return fmt.Errorf("failed to update Windows service: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update firewall rules if service name changed
|
||||
if existingServer.ServiceName != server.ServiceName {
|
||||
if err := s.configureFirewall(server); err != nil {
|
||||
return fmt.Errorf("failed to update firewall rules: %v", err)
|
||||
}
|
||||
// Invalidate cache for old service name
|
||||
s.apiService.statusCache.InvalidateStatus(existingServer.ServiceName)
|
||||
}
|
||||
|
||||
// Update database record
|
||||
if err := s.repository.Update(ctx.UserContext(), server); err != nil {
|
||||
return fmt.Errorf("failed to update server in database: %v", err)
|
||||
}
|
||||
|
||||
// Restart server runtime
|
||||
s.StartAccServerRuntime(server)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerService) configureFirewall(server *model.Server) error {
|
||||
// Find available ports for the server
|
||||
ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find available ports: %v", err)
|
||||
}
|
||||
|
||||
// Use the first port for both TCP and UDP
|
||||
serverPort := ports[0]
|
||||
tcpPorts := []int{serverPort}
|
||||
udpPorts := []int{serverPort}
|
||||
|
||||
logging.Info("Configuring firewall for server %d with port %d", server.ID, serverPort)
|
||||
|
||||
// Configure firewall rules
|
||||
if err := s.firewallService.UpdateServerRules(server.Name, tcpPorts, udpPorts); err != nil {
|
||||
return fmt.Errorf("failed to configure firewall: %v", err)
|
||||
}
|
||||
|
||||
// Update server configuration with the allocated port
|
||||
if err := s.updateServerPort(server, serverPort); err != nil {
|
||||
return fmt.Errorf("failed to update server configuration: %v", err)
|
||||
}
|
||||
@@ -750,7 +613,6 @@ func (s *ServerService) configureFirewall(server *model.Server) error {
|
||||
}
|
||||
|
||||
func (s *ServerService) updateServerPort(server *model.Server, port int) error {
|
||||
// Load current configuration
|
||||
config, err := s.configService.GetConfiguration(server)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load server configuration: %v", err)
|
||||
@@ -759,7 +621,6 @@ func (s *ServerService) updateServerPort(server *model.Server, port int) error {
|
||||
config.TcpPort = model.IntString(port)
|
||||
config.UdpPort = model.IntString(port)
|
||||
|
||||
// Save the updated configuration
|
||||
if err := s.configService.SaveConfiguration(server, config); err != nil {
|
||||
return fmt.Errorf("failed to save server configuration: %v", err)
|
||||
}
|
||||
|
||||
@@ -7,17 +7,12 @@ import (
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
// InitializeServices
|
||||
// Initializes Dependency Injection modules for services
|
||||
//
|
||||
// Args:
|
||||
// *dig.Container: Dig Container
|
||||
// *dig.Container: Dig Container
|
||||
func InitializeServices(c *dig.Container) {
|
||||
logging.Debug("Initializing repositories")
|
||||
repository.InitializeRepositories(c)
|
||||
|
||||
logging.Debug("Registering services")
|
||||
// Provide services
|
||||
c.Provide(NewSteamService)
|
||||
c.Provide(NewServerService)
|
||||
c.Provide(NewStateHistoryService)
|
||||
|
||||
@@ -24,9 +24,9 @@ func NewServiceControlService(repository *repository.ServiceControlRepository,
|
||||
repository: repository,
|
||||
serverRepository: serverRepository,
|
||||
statusCache: model.NewServerStatusCache(model.CacheConfig{
|
||||
ExpirationTime: 30 * time.Second, // Cache expires after 30 seconds
|
||||
ThrottleTime: 5 * time.Second, // Minimum 5 seconds between checks
|
||||
DefaultStatus: model.StatusRunning, // Default to running if throttled
|
||||
ExpirationTime: 30 * time.Second,
|
||||
ThrottleTime: 5 * time.Second,
|
||||
DefaultStatus: model.StatusRunning,
|
||||
}),
|
||||
windowsService: NewWindowsService(),
|
||||
}
|
||||
@@ -42,18 +42,15 @@ func (as *ServiceControlService) GetStatus(ctx *fiber.Ctx) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Try to get status from cache
|
||||
if status, shouldCheck := as.statusCache.GetStatus(serviceName); !shouldCheck {
|
||||
return status.String(), nil
|
||||
}
|
||||
|
||||
// If cache miss or expired, check actual status
|
||||
statusStr, err := as.StatusServer(serviceName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Parse and update cache with new status
|
||||
status := model.ParseServiceStatus(statusStr)
|
||||
as.statusCache.UpdateStatus(serviceName, status)
|
||||
return status.String(), nil
|
||||
@@ -65,7 +62,6 @@ func (as *ServiceControlService) ServiceControlStartServer(ctx *fiber.Ctx) (stri
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Update status cache for this service before starting
|
||||
as.statusCache.UpdateStatus(serviceName, model.StatusStarting)
|
||||
|
||||
_, err = as.StartServer(serviceName)
|
||||
@@ -77,7 +73,6 @@ func (as *ServiceControlService) ServiceControlStartServer(ctx *fiber.Ctx) (stri
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Parse and update cache with new status
|
||||
status := model.ParseServiceStatus(statusStr)
|
||||
as.statusCache.UpdateStatus(serviceName, status)
|
||||
return status.String(), nil
|
||||
@@ -89,7 +84,6 @@ func (as *ServiceControlService) ServiceControlStopServer(ctx *fiber.Ctx) (strin
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Update status cache for this service before stopping
|
||||
as.statusCache.UpdateStatus(serviceName, model.StatusStopping)
|
||||
|
||||
_, err = as.StopServer(serviceName)
|
||||
@@ -101,7 +95,6 @@ func (as *ServiceControlService) ServiceControlStopServer(ctx *fiber.Ctx) (strin
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Parse and update cache with new status
|
||||
status := model.ParseServiceStatus(statusStr)
|
||||
as.statusCache.UpdateStatus(serviceName, status)
|
||||
return status.String(), nil
|
||||
@@ -113,7 +106,6 @@ func (as *ServiceControlService) ServiceControlRestartServer(ctx *fiber.Ctx) (st
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Update status cache for this service before restarting
|
||||
as.statusCache.UpdateStatus(serviceName, model.StatusRestarting)
|
||||
|
||||
_, err = as.RestartServer(serviceName)
|
||||
@@ -125,7 +117,6 @@ func (as *ServiceControlService) ServiceControlRestartServer(ctx *fiber.Ctx) (st
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Parse and update cache with new status
|
||||
status := model.ParseServiceStatus(statusStr)
|
||||
as.statusCache.UpdateStatus(serviceName, status)
|
||||
return status.String(), nil
|
||||
@@ -135,20 +126,16 @@ func (as *ServiceControlService) StatusServer(serviceName string) (string, error
|
||||
return as.windowsService.Status(context.Background(), serviceName)
|
||||
}
|
||||
|
||||
// GetCachedStatus gets the cached status for a service name without requiring fiber context
|
||||
func (as *ServiceControlService) GetCachedStatus(serviceName string) (string, error) {
|
||||
// Try to get status from cache
|
||||
if status, shouldCheck := as.statusCache.GetStatus(serviceName); !shouldCheck {
|
||||
return status.String(), nil
|
||||
}
|
||||
|
||||
// If cache miss or expired, check actual status
|
||||
statusStr, err := as.StatusServer(serviceName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Parse and update cache with new status
|
||||
status := model.ParseServiceStatus(statusStr)
|
||||
as.statusCache.UpdateStatus(serviceName, status)
|
||||
return status.String(), nil
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
type ServiceManager struct {
|
||||
executor *command.CommandExecutor
|
||||
executor *command.CommandExecutor
|
||||
psExecutor *command.CommandExecutor
|
||||
}
|
||||
|
||||
@@ -24,17 +24,14 @@ func NewServiceManager() *ServiceManager {
|
||||
}
|
||||
|
||||
func (s *ServiceManager) ManageService(serviceName, action string) (string, error) {
|
||||
// Run NSSM command through PowerShell to ensure elevation
|
||||
output, err := s.psExecutor.ExecuteWithOutput("-nologo", "-noprofile", ".\\nssm", action, serviceName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Clean up output by removing null bytes and trimming whitespace
|
||||
cleaned := strings.TrimSpace(strings.ReplaceAll(output, "\x00", ""))
|
||||
// Remove \r\n from status strings
|
||||
cleaned = strings.TrimSuffix(cleaned, "\r\n")
|
||||
|
||||
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
@@ -51,11 +48,9 @@ func (s *ServiceManager) Stop(serviceName string) (string, error) {
|
||||
}
|
||||
|
||||
func (s *ServiceManager) Restart(serviceName string) (string, error) {
|
||||
// First stop the service
|
||||
if _, err := s.Stop(serviceName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Then start it again
|
||||
return s.Start(serviceName)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,7 +46,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
|
||||
|
||||
eg, gCtx := errgroup.WithContext(ctx.UserContext())
|
||||
|
||||
// Get Summary Stats (Peak/Avg Players, Total Sessions)
|
||||
eg.Go(func() error {
|
||||
summary, err := s.repository.GetSummaryStats(gCtx, filter)
|
||||
if err != nil {
|
||||
@@ -61,7 +60,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
|
||||
return nil
|
||||
})
|
||||
|
||||
// Get Total Playtime
|
||||
eg.Go(func() error {
|
||||
playtime, err := s.repository.GetTotalPlaytime(gCtx, filter)
|
||||
if err != nil {
|
||||
@@ -74,7 +72,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
|
||||
return nil
|
||||
})
|
||||
|
||||
// Get Player Count Over Time
|
||||
eg.Go(func() error {
|
||||
playerCount, err := s.repository.GetPlayerCountOverTime(gCtx, filter)
|
||||
if err != nil {
|
||||
@@ -87,7 +84,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
|
||||
return nil
|
||||
})
|
||||
|
||||
// Get Session Types
|
||||
eg.Go(func() error {
|
||||
sessionTypes, err := s.repository.GetSessionTypes(gCtx, filter)
|
||||
if err != nil {
|
||||
@@ -100,7 +96,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
|
||||
return nil
|
||||
})
|
||||
|
||||
// Get Daily Activity
|
||||
eg.Go(func() error {
|
||||
dailyActivity, err := s.repository.GetDailyActivity(gCtx, filter)
|
||||
if err != nil {
|
||||
@@ -113,7 +108,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
|
||||
return nil
|
||||
})
|
||||
|
||||
// Get Recent Sessions
|
||||
eg.Go(func() error {
|
||||
recentSessions, err := s.repository.GetRecentSessions(gCtx, filter)
|
||||
if err != nil {
|
||||
|
||||
@@ -22,12 +22,11 @@ const (
|
||||
)
|
||||
|
||||
type SteamService struct {
|
||||
executor *command.CommandExecutor
|
||||
interactiveExecutor *command.InteractiveCommandExecutor
|
||||
repository *repository.SteamCredentialsRepository
|
||||
tfaManager *model.Steam2FAManager
|
||||
pathValidator *security.PathValidator
|
||||
downloadVerifier *security.DownloadVerifier
|
||||
executor *command.CommandExecutor
|
||||
repository *repository.SteamCredentialsRepository
|
||||
tfaManager *model.Steam2FAManager
|
||||
pathValidator *security.PathValidator
|
||||
downloadVerifier *security.DownloadVerifier
|
||||
}
|
||||
|
||||
func NewSteamService(repository *repository.SteamCredentialsRepository, tfaManager *model.Steam2FAManager) *SteamService {
|
||||
@@ -37,12 +36,11 @@ func NewSteamService(repository *repository.SteamCredentialsRepository, tfaManag
|
||||
}
|
||||
|
||||
return &SteamService{
|
||||
executor: baseExecutor,
|
||||
interactiveExecutor: command.NewInteractiveCommandExecutor(baseExecutor, tfaManager),
|
||||
repository: repository,
|
||||
tfaManager: tfaManager,
|
||||
pathValidator: security.NewPathValidator(),
|
||||
downloadVerifier: security.NewDownloadVerifier(),
|
||||
executor: baseExecutor,
|
||||
repository: repository,
|
||||
tfaManager: tfaManager,
|
||||
pathValidator: security.NewPathValidator(),
|
||||
downloadVerifier: security.NewDownloadVerifier(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,21 +56,17 @@ func (s *SteamService) SaveCredentials(ctx context.Context, creds *model.SteamCr
|
||||
}
|
||||
|
||||
func (s *SteamService) ensureSteamCMD(_ context.Context) error {
|
||||
// Get SteamCMD path from environment variable
|
||||
steamCMDPath := env.GetSteamCMDPath()
|
||||
steamCMDDir := filepath.Dir(steamCMDPath)
|
||||
|
||||
// Check if SteamCMD exists
|
||||
if _, err := os.Stat(steamCMDPath); !os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create directory if it doesn't exist
|
||||
if err := os.MkdirAll(steamCMDDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create SteamCMD directory: %v", err)
|
||||
}
|
||||
|
||||
// Download and install SteamCMD securely
|
||||
logging.Info("Downloading SteamCMD...")
|
||||
steamCMDZip := filepath.Join(steamCMDDir, "steamcmd.zip")
|
||||
if err := s.downloadVerifier.VerifyAndDownload(
|
||||
@@ -82,150 +76,27 @@ func (s *SteamService) ensureSteamCMD(_ context.Context) error {
|
||||
return fmt.Errorf("failed to download SteamCMD: %v", err)
|
||||
}
|
||||
|
||||
// Extract SteamCMD
|
||||
logging.Info("Extracting SteamCMD...")
|
||||
if err := s.executor.Execute("-Command",
|
||||
fmt.Sprintf("Expand-Archive -Path 'steamcmd.zip' -DestinationPath '%s'", steamCMDDir)); err != nil {
|
||||
return fmt.Errorf("failed to extract SteamCMD: %v", err)
|
||||
}
|
||||
|
||||
// Clean up zip file
|
||||
os.Remove("steamcmd.zip")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SteamService) InstallServer(ctx context.Context, installPath string, serverID *uuid.UUID) error {
|
||||
if err := s.ensureSteamCMD(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate installation path for security
|
||||
if err := s.pathValidator.ValidateInstallPath(installPath); err != nil {
|
||||
return fmt.Errorf("invalid installation path: %v", err)
|
||||
}
|
||||
|
||||
// Convert to absolute path and ensure proper Windows path format
|
||||
absPath, err := filepath.Abs(installPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get absolute path: %v", err)
|
||||
}
|
||||
absPath = filepath.Clean(absPath)
|
||||
|
||||
// Ensure install path exists
|
||||
if err := os.MkdirAll(absPath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create install directory: %v", err)
|
||||
}
|
||||
|
||||
// Get Steam credentials
|
||||
creds, err := s.GetCredentials(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get Steam credentials: %v", err)
|
||||
}
|
||||
|
||||
// Get SteamCMD path from environment variable
|
||||
steamCMDPath := env.GetSteamCMDPath()
|
||||
|
||||
// Build SteamCMD command arguments
|
||||
steamCMDArgs := []string{
|
||||
"+force_install_dir", absPath,
|
||||
"+login",
|
||||
}
|
||||
|
||||
if creds != nil && creds.Username != "" {
|
||||
logging.Info("Using Steam credentials for user: %s", creds.Username)
|
||||
steamCMDArgs = append(steamCMDArgs, creds.Username)
|
||||
if creds.Password != "" {
|
||||
steamCMDArgs = append(steamCMDArgs, creds.Password)
|
||||
}
|
||||
} else {
|
||||
logging.Info("Using anonymous Steam login")
|
||||
steamCMDArgs = append(steamCMDArgs, "anonymous")
|
||||
}
|
||||
|
||||
steamCMDArgs = append(steamCMDArgs,
|
||||
"+app_update", ACCServerAppID,
|
||||
"validate",
|
||||
"+quit",
|
||||
)
|
||||
|
||||
// Execute SteamCMD directly without PowerShell wrapper to get better output capture
|
||||
args := steamCMDArgs
|
||||
|
||||
// Use interactive executor to handle potential 2FA prompts with timeout
|
||||
logging.Info("Installing ACC server to %s...", absPath)
|
||||
logging.Info("SteamCMD command: %s %s", steamCMDPath, strings.Join(args, " "))
|
||||
|
||||
// Create a context with timeout to prevent hanging indefinitely
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) // Increased timeout
|
||||
defer cancel()
|
||||
|
||||
// Update the executor to use SteamCMD directly
|
||||
originalExePath := s.interactiveExecutor.ExePath
|
||||
s.interactiveExecutor.ExePath = steamCMDPath
|
||||
defer func() {
|
||||
s.interactiveExecutor.ExePath = originalExePath
|
||||
}()
|
||||
|
||||
if err := s.interactiveExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil {
|
||||
logging.Error("SteamCMD execution failed: %v", err)
|
||||
if timeoutCtx.Err() == context.DeadlineExceeded {
|
||||
return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required")
|
||||
}
|
||||
return fmt.Errorf("failed to run SteamCMD: %v", err)
|
||||
}
|
||||
|
||||
logging.Info("SteamCMD execution completed successfully, proceeding with verification...")
|
||||
|
||||
// Add a delay to allow Steam to properly cleanup
|
||||
logging.Info("Waiting for Steam operations to complete...")
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Verify installation
|
||||
exePath := filepath.Join(absPath, "server", "accServer.exe")
|
||||
logging.Info("Checking for ACC server executable at: %s", exePath)
|
||||
|
||||
if _, err := os.Stat(exePath); os.IsNotExist(err) {
|
||||
// Log directory contents to help debug
|
||||
logging.Info("accServer.exe not found, checking directory contents...")
|
||||
if entries, dirErr := os.ReadDir(absPath); dirErr == nil {
|
||||
logging.Info("Contents of %s:", absPath)
|
||||
for _, entry := range entries {
|
||||
logging.Info(" - %s (dir: %v)", entry.Name(), entry.IsDir())
|
||||
}
|
||||
}
|
||||
|
||||
// Check if there's a server subdirectory
|
||||
serverDir := filepath.Join(absPath, "server")
|
||||
if entries, dirErr := os.ReadDir(serverDir); dirErr == nil {
|
||||
logging.Info("Contents of %s:", serverDir)
|
||||
for _, entry := range entries {
|
||||
logging.Info(" - %s (dir: %v)", entry.Name(), entry.IsDir())
|
||||
}
|
||||
} else {
|
||||
logging.Info("Server directory %s does not exist or cannot be read: %v", serverDir, dirErr)
|
||||
}
|
||||
|
||||
return fmt.Errorf("server installation failed: accServer.exe not found in %s", exePath)
|
||||
}
|
||||
|
||||
logging.Info("Server installation completed successfully - accServer.exe found at %s", exePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// InstallServerWithWebSocket installs a server with WebSocket output streaming
|
||||
func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPath string, serverID *uuid.UUID, wsService *WebSocketService) error {
|
||||
if err := s.ensureSteamCMD(ctx); err != nil {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Error ensuring SteamCMD: %v", err), true)
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate installation path for security
|
||||
if err := s.pathValidator.ValidateInstallPath(installPath); err != nil {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Invalid installation path: %v", err), true)
|
||||
return fmt.Errorf("invalid installation path: %v", err)
|
||||
}
|
||||
|
||||
// Convert to absolute path and ensure proper Windows path format
|
||||
absPath, err := filepath.Abs(installPath)
|
||||
if err != nil {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to get absolute path: %v", err), true)
|
||||
@@ -233,7 +104,6 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa
|
||||
}
|
||||
absPath = filepath.Clean(absPath)
|
||||
|
||||
// Ensure install path exists
|
||||
if err := os.MkdirAll(absPath, 0755); err != nil {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to create install directory: %v", err), true)
|
||||
return fmt.Errorf("failed to create install directory: %v", err)
|
||||
@@ -241,17 +111,14 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa
|
||||
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Installation directory prepared: %s", absPath), false)
|
||||
|
||||
// Get Steam credentials
|
||||
creds, err := s.GetCredentials(ctx)
|
||||
if err != nil {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to get Steam credentials: %v", err), true)
|
||||
return fmt.Errorf("failed to get Steam credentials: %v", err)
|
||||
}
|
||||
|
||||
// Get SteamCMD path from environment variable
|
||||
steamCMDPath := env.GetSteamCMDPath()
|
||||
|
||||
// Build SteamCMD command arguments
|
||||
steamCMDArgs := []string{
|
||||
"+force_install_dir", absPath,
|
||||
"+login",
|
||||
@@ -274,27 +141,32 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa
|
||||
"+quit",
|
||||
)
|
||||
|
||||
// Execute SteamCMD with WebSocket output streaming
|
||||
args := steamCMDArgs
|
||||
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Starting SteamCMD: %s %s", steamCMDPath, strings.Join(args, " ")), false)
|
||||
|
||||
// Create a context with timeout to prevent hanging indefinitely
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Update the executor to use SteamCMD directly
|
||||
originalExePath := s.interactiveExecutor.ExePath
|
||||
s.interactiveExecutor.ExePath = steamCMDPath
|
||||
defer func() {
|
||||
s.interactiveExecutor.ExePath = originalExePath
|
||||
}()
|
||||
callbackConfig := &command.CallbackConfig{
|
||||
OnOutput: func(serverID uuid.UUID, output string, isError bool) {
|
||||
wsService.BroadcastSteamOutput(serverID, output, isError)
|
||||
},
|
||||
OnCommand: func(serverID uuid.UUID, command string, args []string, completed bool, success bool, error string) {
|
||||
if completed {
|
||||
if success {
|
||||
wsService.BroadcastSteamOutput(serverID, "Command completed successfully", false)
|
||||
} else {
|
||||
wsService.BroadcastSteamOutput(serverID, fmt.Sprintf("Command failed: %s", error), true)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// Create a modified interactive executor that streams output to WebSocket
|
||||
wsInteractiveExecutor := command.NewInteractiveCommandExecutorWithWebSocket(s.executor, s.tfaManager, wsService, *serverID)
|
||||
wsInteractiveExecutor.ExePath = steamCMDPath
|
||||
callbackInteractiveExecutor := command.NewCallbackInteractiveCommandExecutor(s.executor, s.tfaManager, callbackConfig, *serverID)
|
||||
callbackInteractiveExecutor.ExePath = steamCMDPath
|
||||
|
||||
if err := wsInteractiveExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil {
|
||||
if err := callbackInteractiveExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("SteamCMD execution failed: %v", err), true)
|
||||
if timeoutCtx.Err() == context.DeadlineExceeded {
|
||||
return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required")
|
||||
@@ -304,11 +176,9 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa
|
||||
|
||||
wsService.BroadcastSteamOutput(*serverID, "SteamCMD execution completed successfully, proceeding with verification...", false)
|
||||
|
||||
// Add a delay to allow Steam to properly cleanup
|
||||
wsService.BroadcastSteamOutput(*serverID, "Waiting for Steam operations to complete...", false)
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Verify installation
|
||||
exePath := filepath.Join(absPath, "server", "accServer.exe")
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Checking for ACC server executable at: %s", exePath), false)
|
||||
|
||||
@@ -322,7 +192,6 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa
|
||||
}
|
||||
}
|
||||
|
||||
// Check if there's a server subdirectory
|
||||
serverDir := filepath.Join(absPath, "server")
|
||||
if entries, dirErr := os.ReadDir(serverDir); dirErr == nil {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Contents of %s:", serverDir), false)
|
||||
@@ -341,8 +210,117 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SteamService) UpdateServer(ctx context.Context, installPath string, serverID *uuid.UUID) error {
|
||||
return s.InstallServer(ctx, installPath, serverID) // Same process as install
|
||||
func (s *SteamService) InstallServerWithCallbacks(ctx context.Context, installPath string, serverID *uuid.UUID, outputCallback command.OutputCallback) error {
|
||||
if err := s.ensureSteamCMD(ctx); err != nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("Error ensuring SteamCMD: %v", err), true)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.pathValidator.ValidateInstallPath(installPath); err != nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("Invalid installation path: %v", err), true)
|
||||
return fmt.Errorf("invalid installation path: %v", err)
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(installPath)
|
||||
if err != nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("Failed to get absolute path: %v", err), true)
|
||||
return fmt.Errorf("failed to get absolute path: %v", err)
|
||||
}
|
||||
absPath = filepath.Clean(absPath)
|
||||
|
||||
if err := os.MkdirAll(absPath, 0755); err != nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("Failed to create install directory: %v", err), true)
|
||||
return fmt.Errorf("failed to create install directory: %v", err)
|
||||
}
|
||||
|
||||
outputCallback(*serverID, fmt.Sprintf("Installation directory prepared: %s", absPath), false)
|
||||
|
||||
creds, err := s.GetCredentials(ctx)
|
||||
if err != nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("Failed to get Steam credentials: %v", err), true)
|
||||
return fmt.Errorf("failed to get Steam credentials: %v", err)
|
||||
}
|
||||
|
||||
steamCMDPath := env.GetSteamCMDPath()
|
||||
|
||||
steamCMDArgs := []string{
|
||||
"+force_install_dir", absPath,
|
||||
"+login",
|
||||
}
|
||||
|
||||
if creds != nil && creds.Username != "" {
|
||||
outputCallback(*serverID, fmt.Sprintf("Using Steam credentials for user: %s", creds.Username), false)
|
||||
steamCMDArgs = append(steamCMDArgs, creds.Username)
|
||||
if creds.Password != "" {
|
||||
steamCMDArgs = append(steamCMDArgs, creds.Password)
|
||||
}
|
||||
} else {
|
||||
outputCallback(*serverID, "Using anonymous Steam login", false)
|
||||
steamCMDArgs = append(steamCMDArgs, "anonymous")
|
||||
}
|
||||
|
||||
steamCMDArgs = append(steamCMDArgs,
|
||||
"+app_update", ACCServerAppID,
|
||||
"validate",
|
||||
"+quit",
|
||||
)
|
||||
|
||||
args := steamCMDArgs
|
||||
|
||||
outputCallback(*serverID, fmt.Sprintf("Starting SteamCMD: %s %s", steamCMDPath, strings.Join(args, " ")), false)
|
||||
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
callbacks := &command.CallbackConfig{
|
||||
OnOutput: outputCallback,
|
||||
}
|
||||
|
||||
callbackExecutor := command.NewCallbackInteractiveCommandExecutor(s.executor, s.tfaManager, callbacks, *serverID)
|
||||
callbackExecutor.ExePath = steamCMDPath
|
||||
|
||||
if err := callbackExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("SteamCMD execution failed: %v", err), true)
|
||||
if timeoutCtx.Err() == context.DeadlineExceeded {
|
||||
return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required")
|
||||
}
|
||||
return fmt.Errorf("failed to run SteamCMD: %v", err)
|
||||
}
|
||||
|
||||
outputCallback(*serverID, "SteamCMD execution completed successfully, proceeding with verification...", false)
|
||||
|
||||
outputCallback(*serverID, "Waiting for Steam operations to complete...", false)
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
exePath := filepath.Join(absPath, "server", "accServer.exe")
|
||||
outputCallback(*serverID, fmt.Sprintf("Checking for ACC server executable at: %s", exePath), false)
|
||||
|
||||
if _, err := os.Stat(exePath); os.IsNotExist(err) {
|
||||
outputCallback(*serverID, "accServer.exe not found, checking directory contents...", false)
|
||||
|
||||
if entries, dirErr := os.ReadDir(absPath); dirErr == nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("Contents of %s:", absPath), false)
|
||||
for _, entry := range entries {
|
||||
outputCallback(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false)
|
||||
}
|
||||
}
|
||||
|
||||
serverDir := filepath.Join(absPath, "server")
|
||||
if entries, dirErr := os.ReadDir(serverDir); dirErr == nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("Contents of %s:", serverDir), false)
|
||||
for _, entry := range entries {
|
||||
outputCallback(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false)
|
||||
}
|
||||
} else {
|
||||
outputCallback(*serverID, fmt.Sprintf("Server directory %s does not exist or cannot be read: %v", serverDir, dirErr), true)
|
||||
}
|
||||
|
||||
outputCallback(*serverID, fmt.Sprintf("Server installation failed: accServer.exe not found in %s", exePath), true)
|
||||
return fmt.Errorf("server installation failed: accServer.exe not found in %s", exePath)
|
||||
}
|
||||
|
||||
outputCallback(*serverID, fmt.Sprintf("Server installation completed successfully - accServer.exe found at %s", exePath), false)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SteamService) UninstallServer(installPath string) error {
|
||||
|
||||
@@ -11,25 +11,21 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// WebSocketConnection represents a single WebSocket connection
|
||||
type WebSocketConnection struct {
|
||||
conn *websocket.Conn
|
||||
serverID *uuid.UUID // If connected to a specific server creation process
|
||||
userID *uuid.UUID // User who owns this connection
|
||||
serverID *uuid.UUID
|
||||
userID *uuid.UUID
|
||||
}
|
||||
|
||||
// WebSocketService manages WebSocket connections and message broadcasting
|
||||
type WebSocketService struct {
|
||||
connections sync.Map // map[string]*WebSocketConnection - key is connection ID
|
||||
connections sync.Map
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewWebSocketService creates a new WebSocket service
|
||||
func NewWebSocketService() *WebSocketService {
|
||||
return &WebSocketService{}
|
||||
}
|
||||
|
||||
// AddConnection adds a new WebSocket connection
|
||||
func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, userID *uuid.UUID) {
|
||||
wsConn := &WebSocketConnection{
|
||||
conn: conn,
|
||||
@@ -39,7 +35,6 @@ func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, u
|
||||
logging.Info("WebSocket connection added: %s for user: %v", connID, userID)
|
||||
}
|
||||
|
||||
// RemoveConnection removes a WebSocket connection
|
||||
func (ws *WebSocketService) RemoveConnection(connID string) {
|
||||
if conn, exists := ws.connections.LoadAndDelete(connID); exists {
|
||||
if wsConn, ok := conn.(*WebSocketConnection); ok {
|
||||
@@ -49,7 +44,6 @@ func (ws *WebSocketService) RemoveConnection(connID string) {
|
||||
logging.Info("WebSocket connection removed: %s", connID)
|
||||
}
|
||||
|
||||
// SetServerID associates a connection with a specific server creation process
|
||||
func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) {
|
||||
if conn, exists := ws.connections.Load(connID); exists {
|
||||
if wsConn, ok := conn.(*WebSocketConnection); ok {
|
||||
@@ -58,7 +52,6 @@ func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) {
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastStep sends a step update to all connections associated with a server
|
||||
func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerCreationStep, status model.StepStatus, message string, errorMsg string) {
|
||||
stepMsg := model.StepMessage{
|
||||
Step: step,
|
||||
@@ -77,7 +70,6 @@ func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerC
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
// BroadcastSteamOutput sends Steam command output to all connections associated with a server
|
||||
func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output string, isError bool) {
|
||||
steamMsg := model.SteamOutputMessage{
|
||||
Output: output,
|
||||
@@ -94,7 +86,6 @@ func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output stri
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
// BroadcastError sends an error message to all connections associated with a server
|
||||
func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, details string) {
|
||||
errorMsg := model.ErrorMessage{
|
||||
Error: error,
|
||||
@@ -111,7 +102,6 @@ func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, det
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
// BroadcastComplete sends a completion message to all connections associated with a server
|
||||
func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool, message string) {
|
||||
completeMsg := model.CompleteMessage{
|
||||
ServerID: serverID,
|
||||
@@ -129,7 +119,6 @@ func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool,
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
// broadcastToServer sends a message to all connections associated with a specific server
|
||||
func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model.WebSocketMessage) {
|
||||
data, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
@@ -137,22 +126,35 @@ func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model.
|
||||
return
|
||||
}
|
||||
|
||||
sentToAssociatedConnections := false
|
||||
|
||||
ws.connections.Range(func(key, value interface{}) bool {
|
||||
if wsConn, ok := value.(*WebSocketConnection); ok {
|
||||
// Send to connections associated with this server
|
||||
if wsConn.serverID != nil && *wsConn.serverID == serverID {
|
||||
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
|
||||
// Remove the connection if it's broken
|
||||
ws.RemoveConnection(key.(string))
|
||||
} else {
|
||||
sentToAssociatedConnections = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if !sentToAssociatedConnections && (message.Type == model.MessageTypeStep || message.Type == model.MessageTypeError || message.Type == model.MessageTypeComplete) {
|
||||
ws.connections.Range(func(key, value interface{}) bool {
|
||||
if wsConn, ok := value.(*WebSocketConnection); ok {
|
||||
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
|
||||
ws.RemoveConnection(key.(string))
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastToUser sends a message to all connections owned by a specific user
|
||||
func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebSocketMessage) {
|
||||
data, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
@@ -162,11 +164,9 @@ func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebS
|
||||
|
||||
ws.connections.Range(func(key, value interface{}) bool {
|
||||
if wsConn, ok := value.(*WebSocketConnection); ok {
|
||||
// Send to connections owned by this user
|
||||
if wsConn.userID != nil && *wsConn.userID == userID {
|
||||
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
|
||||
// Remove the connection if it's broken
|
||||
ws.RemoveConnection(key.(string))
|
||||
}
|
||||
}
|
||||
@@ -175,7 +175,6 @@ func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebS
|
||||
})
|
||||
}
|
||||
|
||||
// GetActiveConnections returns the count of active connections
|
||||
func (ws *WebSocketService) GetActiveConnections() int {
|
||||
count := 0
|
||||
ws.connections.Range(func(key, value interface{}) bool {
|
||||
|
||||
@@ -23,34 +23,25 @@ func NewWindowsService() *WindowsService {
|
||||
}
|
||||
}
|
||||
|
||||
// executeNSSM runs an NSSM command through PowerShell with elevation
|
||||
func (s *WindowsService) ExecuteNSSM(ctx context.Context, args ...string) (string, error) {
|
||||
// Get NSSM path from environment variable
|
||||
nssmPath := env.GetNSSMPath()
|
||||
|
||||
// Prepend NSSM path to arguments
|
||||
nssmArgs := append([]string{"-NoProfile", "-NonInteractive", "-Command", "& " + nssmPath}, args...)
|
||||
|
||||
output, err := s.executor.ExecuteWithOutput(nssmArgs...)
|
||||
if err != nil {
|
||||
// Log the full command and error for debugging
|
||||
logging.Error("NSSM command failed: powershell %s", strings.Join(nssmArgs, " "))
|
||||
logging.Error("NSSM error output: %s", output)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Clean up output by removing null bytes and trimming whitespace
|
||||
cleaned := strings.TrimSpace(strings.ReplaceAll(output, "\x00", ""))
|
||||
// Remove \r\n from status strings
|
||||
cleaned = strings.TrimSuffix(cleaned, "\r\n")
|
||||
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
// Service Installation/Configuration Methods
|
||||
|
||||
func (s *WindowsService) CreateService(ctx context.Context, serviceName, execPath, workingDir string, args []string) error {
|
||||
// Ensure paths are absolute and properly formatted for Windows
|
||||
absExecPath, err := filepath.Abs(execPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get absolute path for executable: %v", err)
|
||||
@@ -63,30 +54,24 @@ func (s *WindowsService) CreateService(ctx context.Context, serviceName, execPat
|
||||
}
|
||||
absWorkingDir = filepath.Clean(absWorkingDir)
|
||||
|
||||
// Log the paths being used
|
||||
logging.Info("Creating service '%s' with:", serviceName)
|
||||
logging.Info(" Executable: %s", absExecPath)
|
||||
logging.Info(" Working Directory: %s", absWorkingDir)
|
||||
|
||||
// First remove any existing service with the same name
|
||||
s.ExecuteNSSM(ctx, "remove", serviceName, "confirm")
|
||||
|
||||
// Install service
|
||||
if _, err := s.ExecuteNSSM(ctx, "install", serviceName, absExecPath); err != nil {
|
||||
return fmt.Errorf("failed to install service: %v", err)
|
||||
}
|
||||
|
||||
// Set arguments if provided
|
||||
if len(args) > 0 {
|
||||
cmdArgs := append([]string{"set", serviceName, "AppParameters"}, args...)
|
||||
if _, err := s.ExecuteNSSM(ctx, cmdArgs...); err != nil {
|
||||
// Try to clean up on failure
|
||||
s.ExecuteNSSM(ctx, "remove", serviceName, "confirm")
|
||||
return fmt.Errorf("failed to set arguments: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify service was created
|
||||
if _, err := s.ExecuteNSSM(ctx, "get", serviceName, "Application"); err != nil {
|
||||
return fmt.Errorf("service creation verification failed: %v", err)
|
||||
}
|
||||
@@ -105,17 +90,13 @@ func (s *WindowsService) DeleteService(ctx context.Context, serviceName string)
|
||||
}
|
||||
|
||||
func (s *WindowsService) UpdateService(ctx context.Context, serviceName, execPath, workingDir string, args []string) error {
|
||||
// First remove the existing service
|
||||
if err := s.DeleteService(ctx, serviceName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then create it again with new parameters
|
||||
return s.CreateService(ctx, serviceName, execPath, workingDir, args)
|
||||
}
|
||||
|
||||
// Service Control Methods
|
||||
|
||||
func (s *WindowsService) Status(ctx context.Context, serviceName string) (string, error) {
|
||||
return s.ExecuteNSSM(ctx, "status", serviceName)
|
||||
}
|
||||
@@ -129,11 +110,9 @@ func (s *WindowsService) Stop(ctx context.Context, serviceName string) (string,
|
||||
}
|
||||
|
||||
func (s *WindowsService) Restart(ctx context.Context, serviceName string) (string, error) {
|
||||
// First stop the service
|
||||
if _, err := s.Stop(ctx, serviceName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Then start it again
|
||||
return s.Start(ctx, serviceName)
|
||||
}
|
||||
|
||||
12
local/utl/cache/cache.go
vendored
12
local/utl/cache/cache.go
vendored
@@ -9,26 +9,22 @@ import (
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
// CacheItem represents an item in the cache
|
||||
type CacheItem struct {
|
||||
Value interface{}
|
||||
Expiration int64
|
||||
}
|
||||
|
||||
// InMemoryCache is a thread-safe in-memory cache
|
||||
type InMemoryCache struct {
|
||||
items map[string]CacheItem
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewInMemoryCache creates and returns a new InMemoryCache instance
|
||||
func NewInMemoryCache() *InMemoryCache {
|
||||
return &InMemoryCache{
|
||||
items: make(map[string]CacheItem),
|
||||
}
|
||||
}
|
||||
|
||||
// Set adds an item to the cache with an expiration duration (in seconds)
|
||||
func (c *InMemoryCache) Set(key string, value interface{}, duration time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
@@ -44,7 +40,6 @@ func (c *InMemoryCache) Set(key string, value interface{}, duration time.Duratio
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves an item from the cache
|
||||
func (c *InMemoryCache) Get(key string) (interface{}, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
@@ -55,24 +50,18 @@ func (c *InMemoryCache) Get(key string) (interface{}, bool) {
|
||||
}
|
||||
|
||||
if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration {
|
||||
// Item has expired, but don't delete here to avoid lock upgrade.
|
||||
// It will be overwritten on the next Set.
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return item.Value, true
|
||||
}
|
||||
|
||||
// Delete removes an item from the cache
|
||||
func (c *InMemoryCache) Delete(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.items, key)
|
||||
}
|
||||
|
||||
// GetOrSet retrieves an item from the cache. If the item is not found, it
|
||||
// calls the provided function to get the value, sets it in the cache, and
|
||||
// returns it.
|
||||
func GetOrSet[T any](c *InMemoryCache, key string, duration time.Duration, fetcher func() (T, error)) (T, error) {
|
||||
if cached, found := c.Get(key); found {
|
||||
if value, ok := cached.(T); ok {
|
||||
@@ -90,7 +79,6 @@ func GetOrSet[T any](c *InMemoryCache, key string, duration time.Duration, fetch
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Start initializes the cache and provides it to the DI container.
|
||||
func Start(di *dig.Container) {
|
||||
cache := NewInMemoryCache()
|
||||
err := di.Provide(func() *InMemoryCache {
|
||||
|
||||
245
local/utl/command/callback_executor.go
Normal file
245
local/utl/command/callback_executor.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/model"
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type CallbackInteractiveCommandExecutor struct {
|
||||
*InteractiveCommandExecutor
|
||||
callbacks *CallbackConfig
|
||||
serverID uuid.UUID
|
||||
}
|
||||
|
||||
func NewCallbackInteractiveCommandExecutor(baseExecutor *CommandExecutor, tfaManager *model.Steam2FAManager, callbacks *CallbackConfig, serverID uuid.UUID) *CallbackInteractiveCommandExecutor {
|
||||
if callbacks == nil {
|
||||
callbacks = DefaultCallbackConfig()
|
||||
}
|
||||
|
||||
return &CallbackInteractiveCommandExecutor{
|
||||
InteractiveCommandExecutor: &InteractiveCommandExecutor{
|
||||
CommandExecutor: baseExecutor,
|
||||
tfaManager: tfaManager,
|
||||
},
|
||||
callbacks: callbacks,
|
||||
serverID: serverID,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *CallbackInteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error {
|
||||
cmd := exec.CommandContext(ctx, e.ExePath, args...)
|
||||
|
||||
if e.WorkDir != "" {
|
||||
cmd.Dir = e.WorkDir
|
||||
}
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdin pipe: %v", err)
|
||||
}
|
||||
defer stdin.Close()
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdout pipe: %v", err)
|
||||
}
|
||||
defer stdout.Close()
|
||||
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stderr pipe: %v", err)
|
||||
}
|
||||
defer stderr.Close()
|
||||
|
||||
logging.Info("Executing interactive command with callbacks: %s %s", e.ExePath, strings.Join(args, " "))
|
||||
|
||||
e.callbacks.OnCommand(e.serverID, e.ExePath, args, false, false, "")
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Starting command: %s %s", e.ExePath, strings.Join(args, " ")), false)
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to start command: %v", err), true)
|
||||
return fmt.Errorf("failed to start command: %v", err)
|
||||
}
|
||||
|
||||
outputDone := make(chan error, 1)
|
||||
cmdDone := make(chan error, 1)
|
||||
|
||||
go e.monitorOutputWithCallbacks(ctx, stdout, stderr, serverID, outputDone)
|
||||
|
||||
go func() {
|
||||
cmdDone <- cmd.Wait()
|
||||
}()
|
||||
|
||||
var cmdErr, outputErr error
|
||||
completedCount := 0
|
||||
|
||||
for completedCount < 2 {
|
||||
select {
|
||||
case cmdErr = <-cmdDone:
|
||||
completedCount++
|
||||
logging.Info("Command execution completed")
|
||||
e.callbacks.OnOutput(e.serverID, "Command execution completed", false)
|
||||
case outputErr = <-outputDone:
|
||||
completedCount++
|
||||
logging.Info("Output monitoring completed")
|
||||
case <-ctx.Done():
|
||||
e.callbacks.OnOutput(e.serverID, "Command execution cancelled", true)
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
if outputErr != nil {
|
||||
logging.Warn("Output monitoring error: %v", outputErr)
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Output monitoring error: %v", outputErr), true)
|
||||
}
|
||||
|
||||
success := cmdErr == nil
|
||||
errorMsg := ""
|
||||
if cmdErr != nil {
|
||||
errorMsg = cmdErr.Error()
|
||||
}
|
||||
e.callbacks.OnCommand(e.serverID, e.ExePath, args, true, success, errorMsg)
|
||||
|
||||
return cmdErr
|
||||
}
|
||||
|
||||
func (e *CallbackInteractiveCommandExecutor) monitorOutputWithCallbacks(ctx context.Context, stdout, stderr io.Reader, serverID *uuid.UUID, done chan error) {
|
||||
defer func() {
|
||||
select {
|
||||
case done <- nil:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
stdoutScanner := bufio.NewScanner(stdout)
|
||||
stderrScanner := bufio.NewScanner(stderr)
|
||||
|
||||
outputChan := make(chan outputLine, 100)
|
||||
readersDone := make(chan struct{}, 2)
|
||||
|
||||
steamConsoleStarted := false
|
||||
tfaRequestCreated := false
|
||||
|
||||
go func() {
|
||||
defer func() { readersDone <- struct{}{} }()
|
||||
for stdoutScanner.Scan() {
|
||||
line := stdoutScanner.Text()
|
||||
if e.LogOutput {
|
||||
logging.Info("STDOUT: %s", line)
|
||||
}
|
||||
e.callbacks.OnOutput(e.serverID, line, false)
|
||||
|
||||
select {
|
||||
case outputChan <- outputLine{text: line, isError: false}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := stdoutScanner.Err(); err != nil {
|
||||
logging.Warn("Stdout scanner error: %v", err)
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Stdout scanner error: %v", err), true)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer func() { readersDone <- struct{}{} }()
|
||||
for stderrScanner.Scan() {
|
||||
line := stderrScanner.Text()
|
||||
if e.LogOutput {
|
||||
logging.Info("STDERR: %s", line)
|
||||
}
|
||||
e.callbacks.OnOutput(e.serverID, line, true)
|
||||
|
||||
select {
|
||||
case outputChan <- outputLine{text: line, isError: true}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := stderrScanner.Err(); err != nil {
|
||||
logging.Warn("Stderr scanner error: %v", err)
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Stderr scanner error: %v", err), true)
|
||||
}
|
||||
}()
|
||||
|
||||
readersFinished := 0
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
done <- ctx.Err()
|
||||
return
|
||||
case <-readersDone:
|
||||
readersFinished++
|
||||
if readersFinished == 2 {
|
||||
close(outputChan)
|
||||
for lineData := range outputChan {
|
||||
if e.is2FAPrompt(lineData.text) {
|
||||
if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil {
|
||||
logging.Error("Failed to handle 2FA prompt: %v", err)
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true)
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
case lineData, ok := <-outputChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
lowerLine := strings.ToLower(lineData.text)
|
||||
if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") {
|
||||
steamConsoleStarted = true
|
||||
logging.Info("Steam Console Client startup detected - will monitor for 2FA hang")
|
||||
e.callbacks.OnOutput(e.serverID, "Steam Console Client startup detected", false)
|
||||
}
|
||||
|
||||
if e.is2FAPrompt(lineData.text) {
|
||||
if !tfaRequestCreated {
|
||||
e.callbacks.OnOutput(e.serverID, "2FA prompt detected - waiting for user confirmation", false)
|
||||
if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil {
|
||||
logging.Error("Failed to handle 2FA prompt: %v", err)
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true)
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
tfaRequestCreated = true
|
||||
}
|
||||
}
|
||||
|
||||
if tfaRequestCreated && e.isSteamContinuing(lineData.text) {
|
||||
logging.Info("Steam CMD appears to have continued after 2FA confirmation")
|
||||
e.callbacks.OnOutput(e.serverID, "Steam CMD continued after 2FA confirmation", false)
|
||||
e.autoCompletePendingRequests(serverID)
|
||||
}
|
||||
case <-time.After(15 * time.Second):
|
||||
if steamConsoleStarted && !tfaRequestCreated {
|
||||
logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA")
|
||||
e.callbacks.OnOutput(e.serverID, "Waiting for Steam Guard 2FA confirmation...", false)
|
||||
if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil {
|
||||
logging.Error("Failed to handle Steam Guard 2FA prompt: %v", err)
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle Steam Guard 2FA prompt: %v", err), true)
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
tfaRequestCreated = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type outputLine struct {
|
||||
text string
|
||||
isError bool
|
||||
}
|
||||
19
local/utl/command/callbacks.go
Normal file
19
local/utl/command/callbacks.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package command
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type OutputCallback func(serverID uuid.UUID, output string, isError bool)
|
||||
|
||||
type CommandCallback func(serverID uuid.UUID, command string, args []string, completed bool, success bool, error string)
|
||||
|
||||
type CallbackConfig struct {
|
||||
OnOutput OutputCallback
|
||||
OnCommand CommandCallback
|
||||
}
|
||||
|
||||
func DefaultCallbackConfig() *CallbackConfig {
|
||||
return &CallbackConfig{
|
||||
OnOutput: func(uuid.UUID, string, bool) {},
|
||||
OnCommand: func(uuid.UUID, string, []string, bool, bool, string) {},
|
||||
}
|
||||
}
|
||||
@@ -8,17 +8,12 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CommandExecutor provides a base structure for executing commands
|
||||
type CommandExecutor struct {
|
||||
// Base executable path
|
||||
ExePath string
|
||||
// Working directory for commands
|
||||
WorkDir string
|
||||
// Whether to capture and log output
|
||||
ExePath string
|
||||
WorkDir string
|
||||
LogOutput bool
|
||||
}
|
||||
|
||||
// CommandBuilder helps build command arguments
|
||||
type CommandBuilder struct {
|
||||
args []string
|
||||
}
|
||||
@@ -48,10 +43,9 @@ func (b *CommandBuilder) Build() []string {
|
||||
return b.args
|
||||
}
|
||||
|
||||
// Execute runs a command with the given arguments
|
||||
func (e *CommandExecutor) Execute(args ...string) error {
|
||||
cmd := exec.Command(e.ExePath, args...)
|
||||
|
||||
|
||||
if e.WorkDir != "" {
|
||||
cmd.Dir = e.WorkDir
|
||||
}
|
||||
@@ -65,15 +59,13 @@ func (e *CommandExecutor) Execute(args ...string) error {
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// ExecuteWithBuilder runs a command using a CommandBuilder
|
||||
func (e *CommandExecutor) ExecuteWithBuilder(builder *CommandBuilder) error {
|
||||
return e.Execute(builder.Build()...)
|
||||
}
|
||||
|
||||
// ExecuteWithOutput runs a command and returns its output
|
||||
func (e *CommandExecutor) ExecuteWithOutput(args ...string) (string, error) {
|
||||
cmd := exec.Command(e.ExePath, args...)
|
||||
|
||||
|
||||
if e.WorkDir != "" {
|
||||
cmd.Dir = e.WorkDir
|
||||
}
|
||||
@@ -83,10 +75,9 @@ func (e *CommandExecutor) ExecuteWithOutput(args ...string) (string, error) {
|
||||
return string(output), err
|
||||
}
|
||||
|
||||
// ExecuteWithEnv runs a command with custom environment variables
|
||||
func (e *CommandExecutor) ExecuteWithEnv(env []string, args ...string) error {
|
||||
cmd := exec.Command(e.ExePath, args...)
|
||||
|
||||
|
||||
if e.WorkDir != "" {
|
||||
cmd.Dir = e.WorkDir
|
||||
}
|
||||
@@ -100,4 +91,4 @@ func (e *CommandExecutor) ExecuteWithEnv(env []string, args ...string) error {
|
||||
|
||||
logging.Info("Executing command: %s %s", e.ExePath, strings.Join(args, " "))
|
||||
return cmd.Run()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,14 +9,12 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// InteractiveCommandExecutor extends CommandExecutor to handle interactive commands
|
||||
type InteractiveCommandExecutor struct {
|
||||
*CommandExecutor
|
||||
tfaManager *model.Steam2FAManager
|
||||
@@ -29,7 +27,6 @@ func NewInteractiveCommandExecutor(baseExecutor *CommandExecutor, tfaManager *mo
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteInteractive runs a command that may require 2FA input
|
||||
func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error {
|
||||
cmd := exec.CommandContext(ctx, e.ExePath, args...)
|
||||
|
||||
@@ -37,7 +34,6 @@ func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, ser
|
||||
cmd.Dir = e.WorkDir
|
||||
}
|
||||
|
||||
// Create pipes for stdin, stdout, and stderr
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdin pipe: %v", err)
|
||||
@@ -58,7 +54,6 @@ func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, ser
|
||||
|
||||
logging.Info("Executing interactive command: %s %s", e.ExePath, strings.Join(args, " "))
|
||||
|
||||
// Enable debug mode if environment variable is set
|
||||
debugMode := os.Getenv("STEAMCMD_DEBUG") == "true"
|
||||
if debugMode {
|
||||
logging.Info("STEAMCMD_DEBUG mode enabled - will log all output and create proactive 2FA requests")
|
||||
@@ -68,19 +63,15 @@ func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, ser
|
||||
return fmt.Errorf("failed to start command: %v", err)
|
||||
}
|
||||
|
||||
// Create channels for output monitoring
|
||||
outputDone := make(chan error, 1)
|
||||
cmdDone := make(chan error, 1)
|
||||
|
||||
// Monitor stdout and stderr for 2FA prompts
|
||||
go e.monitorOutput(ctx, stdout, stderr, serverID, outputDone)
|
||||
|
||||
// Wait for the command to finish in a separate goroutine
|
||||
go func() {
|
||||
cmdDone <- cmd.Wait()
|
||||
}()
|
||||
|
||||
// Wait for both command and output monitoring to complete
|
||||
var cmdErr, outputErr error
|
||||
completedCount := 0
|
||||
|
||||
@@ -112,18 +103,15 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
|
||||
}
|
||||
}()
|
||||
|
||||
// Create scanners for both outputs
|
||||
stdoutScanner := bufio.NewScanner(stdout)
|
||||
stderrScanner := bufio.NewScanner(stderr)
|
||||
|
||||
outputChan := make(chan string, 100) // Buffered channel to prevent blocking
|
||||
outputChan := make(chan string, 100)
|
||||
readersDone := make(chan struct{}, 2)
|
||||
|
||||
// Track Steam Console startup for this specific execution
|
||||
steamConsoleStarted := false
|
||||
tfaRequestCreated := false
|
||||
|
||||
// Read from stdout
|
||||
go func() {
|
||||
defer func() { readersDone <- struct{}{} }()
|
||||
for stdoutScanner.Scan() {
|
||||
@@ -131,7 +119,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
|
||||
if e.LogOutput {
|
||||
logging.Info("STDOUT: %s", line)
|
||||
}
|
||||
// Always log Steam CMD output for debugging 2FA issues
|
||||
if strings.Contains(strings.ToLower(line), "steam") {
|
||||
logging.Info("STEAM_DEBUG: %s", line)
|
||||
}
|
||||
@@ -146,7 +133,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
|
||||
}
|
||||
}()
|
||||
|
||||
// Read from stderr
|
||||
go func() {
|
||||
defer func() { readersDone <- struct{}{} }()
|
||||
for stderrScanner.Scan() {
|
||||
@@ -154,7 +140,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
|
||||
if e.LogOutput {
|
||||
logging.Info("STDERR: %s", line)
|
||||
}
|
||||
// Always log Steam CMD errors for debugging 2FA issues
|
||||
if strings.Contains(strings.ToLower(line), "steam") {
|
||||
logging.Info("STEAM_DEBUG_ERR: %s", line)
|
||||
}
|
||||
@@ -169,7 +154,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
|
||||
}
|
||||
}()
|
||||
|
||||
// Monitor for completion and 2FA prompts
|
||||
readersFinished := 0
|
||||
for {
|
||||
select {
|
||||
@@ -179,9 +163,7 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
|
||||
case <-readersDone:
|
||||
readersFinished++
|
||||
if readersFinished == 2 {
|
||||
// Both readers are done, close output channel and finish monitoring
|
||||
close(outputChan)
|
||||
// Drain any remaining output
|
||||
for line := range outputChan {
|
||||
if e.is2FAPrompt(line) {
|
||||
if err := e.handle2FAPrompt(ctx, line, serverID); err != nil {
|
||||
@@ -195,18 +177,15 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
|
||||
}
|
||||
case line, ok := <-outputChan:
|
||||
if !ok {
|
||||
// Channel closed, we're done
|
||||
return
|
||||
}
|
||||
|
||||
// Check for Steam Console startup
|
||||
lowerLine := strings.ToLower(line)
|
||||
if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") {
|
||||
steamConsoleStarted = true
|
||||
logging.Info("Steam Console Client startup detected - will monitor for 2FA hang")
|
||||
}
|
||||
|
||||
// Check if this line indicates a 2FA prompt
|
||||
if e.is2FAPrompt(line) {
|
||||
if !tfaRequestCreated {
|
||||
if err := e.handle2FAPrompt(ctx, line, serverID); err != nil {
|
||||
@@ -218,15 +197,11 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
|
||||
}
|
||||
}
|
||||
|
||||
// Check if Steam CMD continued after 2FA (auto-completion)
|
||||
if tfaRequestCreated && e.isSteamContinuing(line) {
|
||||
logging.Info("Steam CMD appears to have continued after 2FA confirmation - auto-completing 2FA request")
|
||||
// Auto-complete any pending 2FA requests for this server
|
||||
e.autoCompletePendingRequests(serverID)
|
||||
}
|
||||
case <-time.After(15 * time.Second):
|
||||
// If Steam Console has started and we haven't seen output for 15 seconds,
|
||||
// it's very likely waiting for 2FA confirmation
|
||||
if steamConsoleStarted && !tfaRequestCreated {
|
||||
logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA")
|
||||
if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil {
|
||||
@@ -243,7 +218,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout,
|
||||
}
|
||||
|
||||
func (e *InteractiveCommandExecutor) is2FAPrompt(line string) bool {
|
||||
// Common SteamCMD 2FA prompts - updated with more comprehensive patterns
|
||||
twoFAKeywords := []string{
|
||||
"please enter your steam guard code",
|
||||
"steam guard",
|
||||
@@ -272,7 +246,6 @@ func (e *InteractiveCommandExecutor) is2FAPrompt(line string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Also check for patterns that might indicate Steam is waiting for input
|
||||
waitingPatterns := []string{
|
||||
"waiting for",
|
||||
"please enter",
|
||||
@@ -330,12 +303,9 @@ func (e *InteractiveCommandExecutor) autoCompletePendingRequests(serverID *uuid.
|
||||
func (e *InteractiveCommandExecutor) handle2FAPrompt(_ context.Context, promptLine string, serverID *uuid.UUID) error {
|
||||
logging.Info("2FA prompt detected: %s", promptLine)
|
||||
|
||||
// Create a 2FA request
|
||||
request := e.tfaManager.CreateRequest(promptLine, serverID)
|
||||
logging.Info("Created 2FA request with ID: %s", request.ID)
|
||||
|
||||
// Wait for user to complete the 2FA process
|
||||
// Use a reasonable timeout (e.g., 5 minutes)
|
||||
timeout := 5 * time.Minute
|
||||
success, err := e.tfaManager.WaitForCompletion(request.ID, timeout)
|
||||
|
||||
@@ -352,271 +322,3 @@ func (e *InteractiveCommandExecutor) handle2FAPrompt(_ context.Context, promptLi
|
||||
logging.Info("2FA completed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// WebSocketInteractiveCommandExecutor extends InteractiveCommandExecutor to stream output via WebSocket
|
||||
type WebSocketInteractiveCommandExecutor struct {
|
||||
*InteractiveCommandExecutor
|
||||
wsService interface{} // Using interface{} to avoid circular import
|
||||
serverID uuid.UUID
|
||||
}
|
||||
|
||||
// NewInteractiveCommandExecutorWithWebSocket creates a new WebSocket-enabled interactive command executor
|
||||
func NewInteractiveCommandExecutorWithWebSocket(baseExecutor *CommandExecutor, tfaManager *model.Steam2FAManager, wsService interface{}, serverID uuid.UUID) *WebSocketInteractiveCommandExecutor {
|
||||
return &WebSocketInteractiveCommandExecutor{
|
||||
InteractiveCommandExecutor: &InteractiveCommandExecutor{
|
||||
CommandExecutor: baseExecutor,
|
||||
tfaManager: tfaManager,
|
||||
},
|
||||
wsService: wsService,
|
||||
serverID: serverID,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteInteractive runs a command with WebSocket output streaming
|
||||
func (e *WebSocketInteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error {
|
||||
cmd := exec.CommandContext(ctx, e.ExePath, args...)
|
||||
|
||||
if e.WorkDir != "" {
|
||||
cmd.Dir = e.WorkDir
|
||||
}
|
||||
|
||||
// Create pipes for stdin, stdout, and stderr
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdin pipe: %v", err)
|
||||
}
|
||||
defer stdin.Close()
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdout pipe: %v", err)
|
||||
}
|
||||
defer stdout.Close()
|
||||
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stderr pipe: %v", err)
|
||||
}
|
||||
defer stderr.Close()
|
||||
|
||||
logging.Info("Executing interactive command with WebSocket streaming: %s %s", e.ExePath, strings.Join(args, " "))
|
||||
|
||||
// Broadcast command start via WebSocket
|
||||
e.broadcastSteamOutput(fmt.Sprintf("Starting command: %s %s", e.ExePath, strings.Join(args, " ")), false)
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
e.broadcastSteamOutput(fmt.Sprintf("Failed to start command: %v", err), true)
|
||||
return fmt.Errorf("failed to start command: %v", err)
|
||||
}
|
||||
|
||||
// Create channels for output monitoring
|
||||
outputDone := make(chan error, 1)
|
||||
cmdDone := make(chan error, 1)
|
||||
|
||||
// Monitor stdout and stderr for 2FA prompts with WebSocket streaming
|
||||
go e.monitorOutputWithWebSocket(ctx, stdout, stderr, serverID, outputDone)
|
||||
|
||||
// Wait for the command to finish in a separate goroutine
|
||||
go func() {
|
||||
cmdDone <- cmd.Wait()
|
||||
}()
|
||||
|
||||
// Wait for both command and output monitoring to complete
|
||||
var cmdErr, outputErr error
|
||||
completedCount := 0
|
||||
|
||||
for completedCount < 2 {
|
||||
select {
|
||||
case cmdErr = <-cmdDone:
|
||||
completedCount++
|
||||
logging.Info("Command execution completed")
|
||||
e.broadcastSteamOutput("Command execution completed", false)
|
||||
case outputErr = <-outputDone:
|
||||
completedCount++
|
||||
logging.Info("Output monitoring completed")
|
||||
case <-ctx.Done():
|
||||
e.broadcastSteamOutput("Command execution cancelled", true)
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
if outputErr != nil {
|
||||
logging.Warn("Output monitoring error: %v", outputErr)
|
||||
e.broadcastSteamOutput(fmt.Sprintf("Output monitoring error: %v", outputErr), true)
|
||||
}
|
||||
|
||||
return cmdErr
|
||||
}
|
||||
|
||||
// broadcastSteamOutput sends output to WebSocket using reflection to avoid circular imports
|
||||
func (e *WebSocketInteractiveCommandExecutor) broadcastSteamOutput(output string, isError bool) {
|
||||
if e.wsService == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Use reflection to call BroadcastSteamOutput method
|
||||
wsServiceVal := reflect.ValueOf(e.wsService)
|
||||
method := wsServiceVal.MethodByName("BroadcastSteamOutput")
|
||||
if !method.IsValid() {
|
||||
logging.Warn("BroadcastSteamOutput method not found on WebSocket service")
|
||||
return
|
||||
}
|
||||
|
||||
// Call the method with parameters: serverID, output, isError
|
||||
args := []reflect.Value{
|
||||
reflect.ValueOf(e.serverID),
|
||||
reflect.ValueOf(output),
|
||||
reflect.ValueOf(isError),
|
||||
}
|
||||
method.Call(args)
|
||||
}
|
||||
|
||||
// monitorOutputWithWebSocket monitors command output and streams it via WebSocket
|
||||
func (e *WebSocketInteractiveCommandExecutor) monitorOutputWithWebSocket(ctx context.Context, stdout, stderr io.Reader, serverID *uuid.UUID, done chan error) {
|
||||
defer func() {
|
||||
select {
|
||||
case done <- nil:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
// Create scanners for both outputs
|
||||
stdoutScanner := bufio.NewScanner(stdout)
|
||||
stderrScanner := bufio.NewScanner(stderr)
|
||||
|
||||
outputChan := make(chan outputLine, 100) // Buffered channel to prevent blocking
|
||||
readersDone := make(chan struct{}, 2)
|
||||
|
||||
// Track Steam Console startup for this specific execution
|
||||
steamConsoleStarted := false
|
||||
tfaRequestCreated := false
|
||||
|
||||
// Read from stdout
|
||||
go func() {
|
||||
defer func() { readersDone <- struct{}{} }()
|
||||
for stdoutScanner.Scan() {
|
||||
line := stdoutScanner.Text()
|
||||
if e.LogOutput {
|
||||
logging.Info("STDOUT: %s", line)
|
||||
}
|
||||
// Stream output via WebSocket
|
||||
e.broadcastSteamOutput(line, false)
|
||||
|
||||
select {
|
||||
case outputChan <- outputLine{text: line, isError: false}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := stdoutScanner.Err(); err != nil {
|
||||
logging.Warn("Stdout scanner error: %v", err)
|
||||
e.broadcastSteamOutput(fmt.Sprintf("Stdout scanner error: %v", err), true)
|
||||
}
|
||||
}()
|
||||
|
||||
// Read from stderr
|
||||
go func() {
|
||||
defer func() { readersDone <- struct{}{} }()
|
||||
for stderrScanner.Scan() {
|
||||
line := stderrScanner.Text()
|
||||
if e.LogOutput {
|
||||
logging.Info("STDERR: %s", line)
|
||||
}
|
||||
// Stream error output via WebSocket
|
||||
e.broadcastSteamOutput(line, true)
|
||||
|
||||
select {
|
||||
case outputChan <- outputLine{text: line, isError: true}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := stderrScanner.Err(); err != nil {
|
||||
logging.Warn("Stderr scanner error: %v", err)
|
||||
e.broadcastSteamOutput(fmt.Sprintf("Stderr scanner error: %v", err), true)
|
||||
}
|
||||
}()
|
||||
|
||||
// Monitor for completion and 2FA prompts
|
||||
readersFinished := 0
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
done <- ctx.Err()
|
||||
return
|
||||
case <-readersDone:
|
||||
readersFinished++
|
||||
if readersFinished == 2 {
|
||||
// Both readers are done, close output channel and finish monitoring
|
||||
close(outputChan)
|
||||
// Drain any remaining output
|
||||
for lineData := range outputChan {
|
||||
if e.is2FAPrompt(lineData.text) {
|
||||
if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil {
|
||||
logging.Error("Failed to handle 2FA prompt: %v", err)
|
||||
e.broadcastSteamOutput(fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true)
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
case lineData, ok := <-outputChan:
|
||||
if !ok {
|
||||
// Channel closed, we're done
|
||||
return
|
||||
}
|
||||
|
||||
// Check for Steam Console startup
|
||||
lowerLine := strings.ToLower(lineData.text)
|
||||
if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") {
|
||||
steamConsoleStarted = true
|
||||
logging.Info("Steam Console Client startup detected - will monitor for 2FA hang")
|
||||
e.broadcastSteamOutput("Steam Console Client startup detected", false)
|
||||
}
|
||||
|
||||
// Check if this line indicates a 2FA prompt
|
||||
if e.is2FAPrompt(lineData.text) {
|
||||
if !tfaRequestCreated {
|
||||
e.broadcastSteamOutput("2FA prompt detected - waiting for user confirmation", false)
|
||||
if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil {
|
||||
logging.Error("Failed to handle 2FA prompt: %v", err)
|
||||
e.broadcastSteamOutput(fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true)
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
tfaRequestCreated = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if Steam CMD continued after 2FA (auto-completion)
|
||||
if tfaRequestCreated && e.isSteamContinuing(lineData.text) {
|
||||
logging.Info("Steam CMD appears to have continued after 2FA confirmation")
|
||||
e.broadcastSteamOutput("Steam CMD continued after 2FA confirmation", false)
|
||||
// Auto-complete any pending 2FA requests for this server
|
||||
e.autoCompletePendingRequests(serverID)
|
||||
}
|
||||
case <-time.After(15 * time.Second):
|
||||
// If Steam Console has started and we haven't seen output for 15 seconds,
|
||||
// it's very likely waiting for 2FA confirmation
|
||||
if steamConsoleStarted && !tfaRequestCreated {
|
||||
logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA")
|
||||
e.broadcastSteamOutput("Waiting for Steam Guard 2FA confirmation...", false)
|
||||
if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil {
|
||||
logging.Error("Failed to handle Steam Guard 2FA prompt: %v", err)
|
||||
e.broadcastSteamOutput(fmt.Sprintf("Failed to handle Steam Guard 2FA prompt: %v", err), true)
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
tfaRequestCreated = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// outputLine represents a line of output with error status
|
||||
type outputLine struct {
|
||||
text string
|
||||
isError bool
|
||||
}
|
||||
|
||||
@@ -79,12 +79,6 @@ func IndentJson(body []byte) ([]byte, error) {
|
||||
return unmarshaledBody.Bytes(), nil
|
||||
}
|
||||
|
||||
// ParseQueryFilter parses query parameters into a filter struct using reflection.
|
||||
// It supports various field types and uses struct tags to determine parsing behavior.
|
||||
// Supported tags:
|
||||
// - `query:"field_name"` - specifies the query parameter name
|
||||
// - `param:"param_name"` - specifies the path parameter name
|
||||
// - `time_format:"format"` - specifies the time format for parsing dates (default: RFC3339)
|
||||
func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
|
||||
val := reflect.ValueOf(filter)
|
||||
if val.Kind() != reflect.Ptr || val.IsNil() {
|
||||
@@ -94,14 +88,12 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
|
||||
elem := val.Elem()
|
||||
typ := elem.Type()
|
||||
|
||||
// Process all fields including embedded structs
|
||||
var processFields func(reflect.Value, reflect.Type) error
|
||||
processFields = func(val reflect.Value, typ reflect.Type) error {
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := val.Field(i)
|
||||
fieldType := typ.Field(i)
|
||||
|
||||
// Handle embedded structs recursively
|
||||
if fieldType.Anonymous {
|
||||
if err := processFields(field, fieldType.Type); err != nil {
|
||||
return err
|
||||
@@ -109,12 +101,10 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if field cannot be set
|
||||
if !field.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for param tag first (path parameters)
|
||||
if paramName := fieldType.Tag.Get("param"); paramName != "" {
|
||||
if err := parsePathParam(c, field, paramName); err != nil {
|
||||
return fmt.Errorf("error parsing path parameter %s: %v", paramName, err)
|
||||
@@ -122,15 +112,14 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
|
||||
continue
|
||||
}
|
||||
|
||||
// Then check for query tag
|
||||
queryName := fieldType.Tag.Get("query")
|
||||
if queryName == "" {
|
||||
queryName = ToSnakeCase(fieldType.Name) // Default to snake_case of field name
|
||||
queryName = ToSnakeCase(fieldType.Name)
|
||||
}
|
||||
|
||||
queryVal := c.Query(queryName)
|
||||
if queryVal == "" {
|
||||
continue // Skip empty values
|
||||
continue
|
||||
}
|
||||
|
||||
if err := parseValue(field, queryVal, fieldType.Tag); err != nil {
|
||||
|
||||
@@ -18,7 +18,6 @@ var (
|
||||
|
||||
func Init() {
|
||||
godotenv.Load()
|
||||
// Fail fast if critical environment variables are missing
|
||||
Secret = getEnvRequired("APP_SECRET")
|
||||
SecretCode = getEnvRequired("APP_SECRET_CODE")
|
||||
EncryptionKey = getEnvRequired("ENCRYPTION_KEY")
|
||||
@@ -29,7 +28,6 @@ func Init() {
|
||||
}
|
||||
}
|
||||
|
||||
// getEnv retrieves an environment variable or returns a fallback value.
|
||||
func getEnv(key, fallback string) string {
|
||||
if value, exists := os.LookupEnv(key); exists {
|
||||
return value
|
||||
@@ -38,12 +36,10 @@ func getEnv(key, fallback string) string {
|
||||
return fallback
|
||||
}
|
||||
|
||||
// getEnvRequired retrieves an environment variable and fails if it's not set.
|
||||
// This should be used for critical configuration that must not have defaults.
|
||||
func getEnvRequired(key string) string {
|
||||
if value, exists := os.LookupEnv(key); exists && value != "" {
|
||||
return value
|
||||
}
|
||||
log.Fatalf("Required environment variable %s is not set or is empty", key)
|
||||
return "" // This line will never be reached due to log.Fatalf
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -33,7 +33,6 @@ func Start(di *dig.Container) {
|
||||
func Migrate(db *gorm.DB) {
|
||||
logging.Info("Migrating database")
|
||||
|
||||
// Run GORM AutoMigrate for all models
|
||||
err := db.AutoMigrate(
|
||||
&model.ServiceControlModel{},
|
||||
&model.Config{},
|
||||
@@ -52,7 +51,6 @@ func Migrate(db *gorm.DB) {
|
||||
|
||||
if err != nil {
|
||||
logging.Error("GORM AutoMigrate failed: %v", err)
|
||||
// Don't panic, just log the error as custom migrations may have handled this
|
||||
}
|
||||
|
||||
db.FirstOrCreate(&model.ServiceControlModel{ServiceControl: "Works"})
|
||||
@@ -63,10 +61,8 @@ func Migrate(db *gorm.DB) {
|
||||
func runMigrations(db *gorm.DB) {
|
||||
logging.Info("Running custom database migrations...")
|
||||
|
||||
// Migration 001: Password security upgrade
|
||||
if err := migrations.RunPasswordSecurityMigration(db); err != nil {
|
||||
logging.Error("Failed to run password security migration: %v", err)
|
||||
// Continue - this migration might not be needed for all setups
|
||||
}
|
||||
|
||||
logging.Info("Custom database migrations completed")
|
||||
@@ -132,7 +128,6 @@ func seedCarModels(db *gorm.DB) error {
|
||||
carModels := []model.CarModel{
|
||||
{Value: 0, CarModel: "Porsche 991 GT3 R"},
|
||||
{Value: 1, CarModel: "Mercedes-AMG GT3"},
|
||||
// ... Add all car models from your list
|
||||
}
|
||||
|
||||
for _, cm := range carModels {
|
||||
|
||||
7
local/utl/env/env.go
vendored
7
local/utl/env/env.go
vendored
@@ -6,12 +6,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// Default paths for when environment variables are not set
|
||||
DefaultSteamCMDPath = "c:\\steamcmd\\steamcmd.exe"
|
||||
DefaultNSSMPath = ".\\nssm.exe"
|
||||
)
|
||||
|
||||
// GetSteamCMDPath returns the SteamCMD executable path from environment variable or default
|
||||
func GetSteamCMDPath() string {
|
||||
if path := os.Getenv("STEAMCMD_PATH"); path != "" {
|
||||
return path
|
||||
@@ -19,13 +17,11 @@ func GetSteamCMDPath() string {
|
||||
return DefaultSteamCMDPath
|
||||
}
|
||||
|
||||
// GetSteamCMDDirPath returns the directory containing SteamCMD executable
|
||||
func GetSteamCMDDirPath() string {
|
||||
steamCMDPath := GetSteamCMDPath()
|
||||
return filepath.Dir(steamCMDPath)
|
||||
}
|
||||
|
||||
// GetNSSMPath returns the NSSM executable path from environment variable or default
|
||||
func GetNSSMPath() string {
|
||||
if path := os.Getenv("NSSM_PATH"); path != "" {
|
||||
return path
|
||||
@@ -33,17 +29,14 @@ func GetNSSMPath() string {
|
||||
return DefaultNSSMPath
|
||||
}
|
||||
|
||||
// ValidatePaths checks if the configured paths exist (optional validation)
|
||||
func ValidatePaths() map[string]error {
|
||||
errors := make(map[string]error)
|
||||
|
||||
// Check SteamCMD path
|
||||
steamCMDPath := GetSteamCMDPath()
|
||||
if _, err := os.Stat(steamCMDPath); os.IsNotExist(err) {
|
||||
errors["STEAMCMD_PATH"] = err
|
||||
}
|
||||
|
||||
// Check NSSM path
|
||||
nssmPath := GetNSSMPath()
|
||||
if _, err := os.Stat(nssmPath); os.IsNotExist(err) {
|
||||
errors["NSSM_PATH"] = err
|
||||
|
||||
@@ -9,45 +9,37 @@ import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// ControllerErrorHandler provides centralized error handling for controllers
|
||||
type ControllerErrorHandler struct {
|
||||
errorLogger *logging.ErrorLogger
|
||||
}
|
||||
|
||||
// NewControllerErrorHandler creates a new controller error handler instance
|
||||
func NewControllerErrorHandler() *ControllerErrorHandler {
|
||||
return &ControllerErrorHandler{
|
||||
errorLogger: logging.GetErrorLogger(),
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorResponse represents a standardized error response
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Code int `json:"code,omitempty"`
|
||||
Details map[string]string `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// HandleError handles controller errors with logging and standardized responses
|
||||
func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCode int, context ...string) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get caller information for logging
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
file = strings.TrimPrefix(file, "acc-server-manager/")
|
||||
|
||||
// Build context string
|
||||
contextStr := ""
|
||||
if len(context) > 0 {
|
||||
contextStr = fmt.Sprintf("[%s] ", strings.Join(context, "|"))
|
||||
}
|
||||
|
||||
// Clean error message (remove null bytes)
|
||||
cleanErrorMsg := strings.ReplaceAll(err.Error(), "\x00", "")
|
||||
|
||||
// Log the error with context
|
||||
ceh.errorLogger.LogWithContext(
|
||||
fmt.Sprintf("CONTROLLER_ERROR [%s:%d]", file, line),
|
||||
"%s%s",
|
||||
@@ -55,31 +47,26 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo
|
||||
cleanErrorMsg,
|
||||
)
|
||||
|
||||
// Create standardized error response
|
||||
errorResponse := ErrorResponse{
|
||||
Error: cleanErrorMsg,
|
||||
Code: statusCode,
|
||||
}
|
||||
|
||||
// Add request details if available
|
||||
if c != nil {
|
||||
if errorResponse.Details == nil {
|
||||
errorResponse.Details = make(map[string]string)
|
||||
}
|
||||
|
||||
// Safely extract request details
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// If any of these panic, just skip adding the details
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
errorResponse.Details["method"] = c.Method()
|
||||
errorResponse.Details["path"] = c.Path()
|
||||
|
||||
// Safely get IP address
|
||||
|
||||
if ip := c.IP(); ip != "" {
|
||||
errorResponse.Details["ip"] = ip
|
||||
} else {
|
||||
@@ -88,14 +75,11 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo
|
||||
}()
|
||||
}
|
||||
|
||||
// Return appropriate response based on status code
|
||||
if c == nil {
|
||||
// If context is nil, we can't return a response
|
||||
return fmt.Errorf("cannot return HTTP response: context is nil")
|
||||
}
|
||||
|
||||
|
||||
if statusCode >= 500 {
|
||||
// For server errors, don't expose internal details
|
||||
return c.Status(statusCode).JSON(ErrorResponse{
|
||||
Error: "Internal server error",
|
||||
Code: statusCode,
|
||||
@@ -105,52 +89,42 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo
|
||||
return c.Status(statusCode).JSON(errorResponse)
|
||||
}
|
||||
|
||||
// HandleValidationError handles validation errors specifically
|
||||
func (ceh *ControllerErrorHandler) HandleValidationError(c *fiber.Ctx, err error, field string) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusBadRequest, "VALIDATION", field)
|
||||
}
|
||||
|
||||
// HandleDatabaseError handles database-related errors
|
||||
func (ceh *ControllerErrorHandler) HandleDatabaseError(c *fiber.Ctx, err error) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusInternalServerError, "DATABASE")
|
||||
}
|
||||
|
||||
// HandleAuthError handles authentication/authorization errors
|
||||
func (ceh *ControllerErrorHandler) HandleAuthError(c *fiber.Ctx, err error) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusUnauthorized, "AUTH")
|
||||
}
|
||||
|
||||
// HandleNotFoundError handles resource not found errors
|
||||
func (ceh *ControllerErrorHandler) HandleNotFoundError(c *fiber.Ctx, resource string) error {
|
||||
err := fmt.Errorf("%s not found", resource)
|
||||
return ceh.HandleError(c, err, fiber.StatusNotFound, "NOT_FOUND")
|
||||
}
|
||||
|
||||
// HandleBusinessLogicError handles business logic errors
|
||||
func (ceh *ControllerErrorHandler) HandleBusinessLogicError(c *fiber.Ctx, err error) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusBadRequest, "BUSINESS_LOGIC")
|
||||
}
|
||||
|
||||
// HandleServiceError handles service layer errors
|
||||
func (ceh *ControllerErrorHandler) HandleServiceError(c *fiber.Ctx, err error) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusInternalServerError, "SERVICE")
|
||||
}
|
||||
|
||||
// HandleParsingError handles request parsing errors
|
||||
func (ceh *ControllerErrorHandler) HandleParsingError(c *fiber.Ctx, err error) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusBadRequest, "PARSING")
|
||||
}
|
||||
|
||||
// HandleUUIDError handles UUID parsing errors
|
||||
func (ceh *ControllerErrorHandler) HandleUUIDError(c *fiber.Ctx, field string) error {
|
||||
err := fmt.Errorf("invalid %s format", field)
|
||||
return ceh.HandleError(c, err, fiber.StatusBadRequest, "UUID_VALIDATION", field)
|
||||
}
|
||||
|
||||
// Global controller error handler instance
|
||||
var globalErrorHandler *ControllerErrorHandler
|
||||
|
||||
// GetControllerErrorHandler returns the global controller error handler instance
|
||||
func GetControllerErrorHandler() *ControllerErrorHandler {
|
||||
if globalErrorHandler == nil {
|
||||
globalErrorHandler = NewControllerErrorHandler()
|
||||
@@ -158,49 +132,38 @@ func GetControllerErrorHandler() *ControllerErrorHandler {
|
||||
return globalErrorHandler
|
||||
}
|
||||
|
||||
// Convenience functions using the global error handler
|
||||
|
||||
// HandleError handles controller errors using the global error handler
|
||||
func HandleError(c *fiber.Ctx, err error, statusCode int, context ...string) error {
|
||||
return GetControllerErrorHandler().HandleError(c, err, statusCode, context...)
|
||||
}
|
||||
|
||||
// HandleValidationError handles validation errors using the global error handler
|
||||
func HandleValidationError(c *fiber.Ctx, err error, field string) error {
|
||||
return GetControllerErrorHandler().HandleValidationError(c, err, field)
|
||||
}
|
||||
|
||||
// HandleDatabaseError handles database errors using the global error handler
|
||||
func HandleDatabaseError(c *fiber.Ctx, err error) error {
|
||||
return GetControllerErrorHandler().HandleDatabaseError(c, err)
|
||||
}
|
||||
|
||||
// HandleAuthError handles auth errors using the global error handler
|
||||
func HandleAuthError(c *fiber.Ctx, err error) error {
|
||||
return GetControllerErrorHandler().HandleAuthError(c, err)
|
||||
}
|
||||
|
||||
// HandleNotFoundError handles not found errors using the global error handler
|
||||
func HandleNotFoundError(c *fiber.Ctx, resource string) error {
|
||||
return GetControllerErrorHandler().HandleNotFoundError(c, resource)
|
||||
}
|
||||
|
||||
// HandleBusinessLogicError handles business logic errors using the global error handler
|
||||
func HandleBusinessLogicError(c *fiber.Ctx, err error) error {
|
||||
return GetControllerErrorHandler().HandleBusinessLogicError(c, err)
|
||||
}
|
||||
|
||||
// HandleServiceError handles service errors using the global error handler
|
||||
func HandleServiceError(c *fiber.Ctx, err error) error {
|
||||
return GetControllerErrorHandler().HandleServiceError(c, err)
|
||||
}
|
||||
|
||||
// HandleParsingError handles parsing errors using the global error handler
|
||||
func HandleParsingError(c *fiber.Ctx, err error) error {
|
||||
return GetControllerErrorHandler().HandleParsingError(c, err)
|
||||
}
|
||||
|
||||
// HandleUUIDError handles UUID errors using the global error handler
|
||||
func HandleUUIDError(c *fiber.Ctx, field string) error {
|
||||
return GetControllerErrorHandler().HandleUUIDError(c, field)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
// Claims represents the JWT claims.
|
||||
type Claims struct {
|
||||
UserID string `json:"user_id"`
|
||||
IsOpenToken bool `json:"is_open_token"`
|
||||
@@ -27,7 +26,6 @@ type OpenJWTHandler struct {
|
||||
*JWTHandler
|
||||
}
|
||||
|
||||
// NewJWTHandler creates a new JWTHandler instance with the provided secret key.
|
||||
func NewOpenJWTHandler(jwtSecret string) *OpenJWTHandler {
|
||||
jwtHandler := NewJWTHandler(jwtSecret)
|
||||
jwtHandler.IsOpenToken = true
|
||||
@@ -36,7 +34,6 @@ func NewOpenJWTHandler(jwtSecret string) *OpenJWTHandler {
|
||||
}
|
||||
}
|
||||
|
||||
// NewJWTHandler creates a new JWTHandler instance with the provided secret key.
|
||||
func NewJWTHandler(jwtSecret string) *JWTHandler {
|
||||
if jwtSecret == "" {
|
||||
errors.SafeFatal("JWT_SECRET environment variable is required and cannot be empty")
|
||||
@@ -44,14 +41,12 @@ func NewJWTHandler(jwtSecret string) *JWTHandler {
|
||||
|
||||
var secretKey []byte
|
||||
|
||||
// Decode base64 secret if it looks like base64, otherwise use as-is
|
||||
if decoded, err := base64.StdEncoding.DecodeString(jwtSecret); err == nil && len(decoded) >= 32 {
|
||||
secretKey = decoded
|
||||
} else {
|
||||
secretKey = []byte(jwtSecret)
|
||||
}
|
||||
|
||||
// Ensure minimum key length for security
|
||||
if len(secretKey) < 32 {
|
||||
errors.SafeFatal("JWT_SECRET must be at least 32 bytes long for security")
|
||||
}
|
||||
@@ -60,8 +55,6 @@ func NewJWTHandler(jwtSecret string) *JWTHandler {
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSecretKey generates a cryptographically secure random key for JWT signing
|
||||
// This is a utility function for generating new secrets, not used in normal operation
|
||||
func (jh *JWTHandler) GenerateSecretKey() string {
|
||||
key := make([]byte, 64) // 512 bits
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
@@ -70,7 +63,6 @@ func (jh *JWTHandler) GenerateSecretKey() string {
|
||||
return base64.StdEncoding.EncodeToString(key)
|
||||
}
|
||||
|
||||
// GenerateToken generates a new JWT for a given user.
|
||||
func (jh *JWTHandler) GenerateToken(userId string) (string, error) {
|
||||
expirationTime := time.Now().Add(24 * time.Hour)
|
||||
claims := &Claims{
|
||||
@@ -99,7 +91,6 @@ func (jh *JWTHandler) GenerateTokenWithExpiry(user *model.User, expiry time.Time
|
||||
return token.SignedString(jh.SecretKey)
|
||||
}
|
||||
|
||||
// ValidateToken validates a JWT and returns the claims if the token is valid.
|
||||
func (jh *JWTHandler) ValidateToken(tokenString string) (*Claims, error) {
|
||||
claims := &Claims{}
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ var (
|
||||
timeFormat = "2006-01-02 15:04:05.000"
|
||||
)
|
||||
|
||||
// BaseLogger provides the core logging functionality
|
||||
type BaseLogger struct {
|
||||
file *os.File
|
||||
logger *log.Logger
|
||||
@@ -23,7 +22,6 @@ type BaseLogger struct {
|
||||
initialized bool
|
||||
}
|
||||
|
||||
// LogLevel represents different logging levels
|
||||
type LogLevel string
|
||||
|
||||
const (
|
||||
@@ -34,28 +32,23 @@ const (
|
||||
LogLevelPanic LogLevel = "PANIC"
|
||||
)
|
||||
|
||||
// Initialize creates a new base logger instance
|
||||
func InitializeBase(tp string) (*BaseLogger, error) {
|
||||
return newBaseLogger(tp)
|
||||
}
|
||||
|
||||
func newBaseLogger(tp string) (*BaseLogger, error) {
|
||||
// Ensure logs directory exists
|
||||
if err := os.MkdirAll("logs", 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create logs directory: %v", err)
|
||||
}
|
||||
|
||||
// Open log file with date in name
|
||||
logPath := filepath.Join("logs", fmt.Sprintf("acc-server-%s-%s.log", time.Now().Format("2006-01-02"), tp))
|
||||
file, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0666)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open log file: %v", err)
|
||||
}
|
||||
|
||||
// Create multi-writer for both file and console
|
||||
multiWriter := io.MultiWriter(file, os.Stdout)
|
||||
|
||||
// Create base logger
|
||||
logger := &BaseLogger{
|
||||
file: file,
|
||||
logger: log.New(multiWriter, "", 0),
|
||||
@@ -65,13 +58,11 @@ func newBaseLogger(tp string) (*BaseLogger, error) {
|
||||
return logger, nil
|
||||
}
|
||||
|
||||
// GetBaseLogger creates and returns a new base logger instance
|
||||
func GetBaseLogger(tp string) *BaseLogger {
|
||||
baseLogger, _ := InitializeBase(tp)
|
||||
return baseLogger
|
||||
}
|
||||
|
||||
// Close closes the log file
|
||||
func (bl *BaseLogger) Close() error {
|
||||
bl.mu.Lock()
|
||||
defer bl.mu.Unlock()
|
||||
@@ -82,7 +73,6 @@ func (bl *BaseLogger) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Log writes a log entry with the specified level
|
||||
func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) {
|
||||
if bl == nil || !bl.initialized {
|
||||
return
|
||||
@@ -91,14 +81,11 @@ func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) {
|
||||
bl.mu.RLock()
|
||||
defer bl.mu.RUnlock()
|
||||
|
||||
// Get caller info (skip 2 frames: this function and the calling Log function)
|
||||
_, file, line, _ := runtime.Caller(2)
|
||||
file = filepath.Base(file)
|
||||
|
||||
// Format message
|
||||
msg := fmt.Sprintf(format, v...)
|
||||
|
||||
// Format final log line
|
||||
logLine := fmt.Sprintf("[%s] [%s] [%s:%d] %s",
|
||||
time.Now().Format(timeFormat),
|
||||
string(level),
|
||||
@@ -110,7 +97,6 @@ func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) {
|
||||
bl.logger.Println(logLine)
|
||||
}
|
||||
|
||||
// LogWithCaller writes a log entry with custom caller depth
|
||||
func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format string, v ...interface{}) {
|
||||
if bl == nil || !bl.initialized {
|
||||
return
|
||||
@@ -119,14 +105,11 @@ func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format stri
|
||||
bl.mu.RLock()
|
||||
defer bl.mu.RUnlock()
|
||||
|
||||
// Get caller info with custom depth
|
||||
_, file, line, _ := runtime.Caller(callerDepth)
|
||||
file = filepath.Base(file)
|
||||
|
||||
// Format message
|
||||
msg := fmt.Sprintf(format, v...)
|
||||
|
||||
// Format final log line
|
||||
logLine := fmt.Sprintf("[%s] [%s] [%s:%d] %s",
|
||||
time.Now().Format(timeFormat),
|
||||
string(level),
|
||||
@@ -138,7 +121,6 @@ func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format stri
|
||||
bl.logger.Println(logLine)
|
||||
}
|
||||
|
||||
// IsInitialized returns whether the base logger is initialized
|
||||
func (bl *BaseLogger) IsInitialized() bool {
|
||||
if bl == nil {
|
||||
return false
|
||||
@@ -148,19 +130,16 @@ func (bl *BaseLogger) IsInitialized() bool {
|
||||
return bl.initialized
|
||||
}
|
||||
|
||||
// RecoverAndLog recovers from panics and logs them
|
||||
func RecoverAndLog() {
|
||||
baseLogger := GetBaseLogger("panic")
|
||||
if baseLogger != nil && baseLogger.IsInitialized() {
|
||||
if r := recover(); r != nil {
|
||||
// Get stack trace
|
||||
buf := make([]byte, 4096)
|
||||
n := runtime.Stack(buf, false)
|
||||
stackTrace := string(buf[:n])
|
||||
|
||||
baseLogger.LogWithCaller(LogLevelPanic, 2, "Recovered from panic: %v\nStack Trace:\n%s", r, stackTrace)
|
||||
|
||||
// Re-panic to maintain original behavior if needed
|
||||
panic(r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,12 +6,10 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// DebugLogger handles debug-level logging
|
||||
type DebugLogger struct {
|
||||
base *BaseLogger
|
||||
}
|
||||
|
||||
// NewDebugLogger creates a new debug logger instance
|
||||
func NewDebugLogger() *DebugLogger {
|
||||
base, _ := InitializeBase("debug")
|
||||
return &DebugLogger{
|
||||
@@ -19,14 +17,12 @@ func NewDebugLogger() *DebugLogger {
|
||||
}
|
||||
}
|
||||
|
||||
// Log writes a debug-level log entry
|
||||
func (dl *DebugLogger) Log(format string, v ...interface{}) {
|
||||
if dl.base != nil {
|
||||
dl.base.Log(LogLevelDebug, format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithContext writes a debug-level log entry with additional context
|
||||
func (dl *DebugLogger) LogWithContext(context string, format string, v ...interface{}) {
|
||||
if dl.base != nil {
|
||||
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
|
||||
@@ -34,7 +30,6 @@ func (dl *DebugLogger) LogWithContext(context string, format string, v ...interf
|
||||
}
|
||||
}
|
||||
|
||||
// LogFunction logs function entry and exit for debugging
|
||||
func (dl *DebugLogger) LogFunction(functionName string, args ...interface{}) {
|
||||
if dl.base != nil {
|
||||
if len(args) > 0 {
|
||||
@@ -45,21 +40,18 @@ func (dl *DebugLogger) LogFunction(functionName string, args ...interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// LogVariable logs variable values for debugging
|
||||
func (dl *DebugLogger) LogVariable(varName string, value interface{}) {
|
||||
if dl.base != nil {
|
||||
dl.base.Log(LogLevelDebug, "VARIABLE [%s]: %+v", varName, value)
|
||||
}
|
||||
}
|
||||
|
||||
// LogState logs application state information
|
||||
func (dl *DebugLogger) LogState(component string, state interface{}) {
|
||||
if dl.base != nil {
|
||||
dl.base.Log(LogLevelDebug, "STATE [%s]: %+v", component, state)
|
||||
}
|
||||
}
|
||||
|
||||
// LogSQL logs SQL queries for debugging
|
||||
func (dl *DebugLogger) LogSQL(query string, args ...interface{}) {
|
||||
if dl.base != nil {
|
||||
if len(args) > 0 {
|
||||
@@ -70,7 +62,6 @@ func (dl *DebugLogger) LogSQL(query string, args ...interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// LogMemory logs memory usage information
|
||||
func (dl *DebugLogger) LogMemory() {
|
||||
if dl.base != nil {
|
||||
var m runtime.MemStats
|
||||
@@ -80,32 +71,27 @@ func (dl *DebugLogger) LogMemory() {
|
||||
}
|
||||
}
|
||||
|
||||
// LogGoroutines logs current number of goroutines
|
||||
func (dl *DebugLogger) LogGoroutines() {
|
||||
if dl.base != nil {
|
||||
dl.base.Log(LogLevelDebug, "GOROUTINES: %d active", runtime.NumGoroutine())
|
||||
}
|
||||
}
|
||||
|
||||
// LogTiming logs timing information for performance debugging
|
||||
func (dl *DebugLogger) LogTiming(operation string, duration interface{}) {
|
||||
if dl.base != nil {
|
||||
dl.base.Log(LogLevelDebug, "TIMING [%s]: %v", operation, duration)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to convert bytes to kilobytes
|
||||
func bToKb(b uint64) uint64 {
|
||||
return b / 1024
|
||||
}
|
||||
|
||||
// Global debug logger instance
|
||||
var (
|
||||
debugLogger *DebugLogger
|
||||
debugOnce sync.Once
|
||||
)
|
||||
|
||||
// GetDebugLogger returns the global debug logger instance
|
||||
func GetDebugLogger() *DebugLogger {
|
||||
debugOnce.Do(func() {
|
||||
debugLogger = NewDebugLogger()
|
||||
@@ -113,47 +99,38 @@ func GetDebugLogger() *DebugLogger {
|
||||
return debugLogger
|
||||
}
|
||||
|
||||
// Debug logs a debug-level message using the global debug logger
|
||||
func Debug(format string, v ...interface{}) {
|
||||
GetDebugLogger().Log(format, v...)
|
||||
}
|
||||
|
||||
// DebugWithContext logs a debug-level message with context using the global debug logger
|
||||
func DebugWithContext(context string, format string, v ...interface{}) {
|
||||
GetDebugLogger().LogWithContext(context, format, v...)
|
||||
}
|
||||
|
||||
// DebugFunction logs function entry and exit using the global debug logger
|
||||
func DebugFunction(functionName string, args ...interface{}) {
|
||||
GetDebugLogger().LogFunction(functionName, args...)
|
||||
}
|
||||
|
||||
// DebugVariable logs variable values using the global debug logger
|
||||
func DebugVariable(varName string, value interface{}) {
|
||||
GetDebugLogger().LogVariable(varName, value)
|
||||
}
|
||||
|
||||
// DebugState logs application state information using the global debug logger
|
||||
func DebugState(component string, state interface{}) {
|
||||
GetDebugLogger().LogState(component, state)
|
||||
}
|
||||
|
||||
// DebugSQL logs SQL queries using the global debug logger
|
||||
func DebugSQL(query string, args ...interface{}) {
|
||||
GetDebugLogger().LogSQL(query, args...)
|
||||
}
|
||||
|
||||
// DebugMemory logs memory usage information using the global debug logger
|
||||
func DebugMemory() {
|
||||
GetDebugLogger().LogMemory()
|
||||
}
|
||||
|
||||
// DebugGoroutines logs current number of goroutines using the global debug logger
|
||||
func DebugGoroutines() {
|
||||
GetDebugLogger().LogGoroutines()
|
||||
}
|
||||
|
||||
// DebugTiming logs timing information using the global debug logger
|
||||
func DebugTiming(operation string, duration interface{}) {
|
||||
GetDebugLogger().LogTiming(operation, duration)
|
||||
}
|
||||
|
||||
@@ -6,12 +6,10 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ErrorLogger handles error-level logging
|
||||
type ErrorLogger struct {
|
||||
base *BaseLogger
|
||||
}
|
||||
|
||||
// NewErrorLogger creates a new error logger instance
|
||||
func NewErrorLogger() *ErrorLogger {
|
||||
base, _ := InitializeBase("error")
|
||||
return &ErrorLogger{
|
||||
@@ -19,14 +17,12 @@ func NewErrorLogger() *ErrorLogger {
|
||||
}
|
||||
}
|
||||
|
||||
// Log writes an error-level log entry
|
||||
func (el *ErrorLogger) Log(format string, v ...interface{}) {
|
||||
if el.base != nil {
|
||||
el.base.Log(LogLevelError, format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithContext writes an error-level log entry with additional context
|
||||
func (el *ErrorLogger) LogWithContext(context string, format string, v ...interface{}) {
|
||||
if el.base != nil {
|
||||
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
|
||||
@@ -34,7 +30,6 @@ func (el *ErrorLogger) LogWithContext(context string, format string, v ...interf
|
||||
}
|
||||
}
|
||||
|
||||
// LogError logs an error object with optional message
|
||||
func (el *ErrorLogger) LogError(err error, message ...string) {
|
||||
if el.base != nil && err != nil {
|
||||
if len(message) > 0 {
|
||||
@@ -45,7 +40,6 @@ func (el *ErrorLogger) LogError(err error, message ...string) {
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithStackTrace logs an error with stack trace
|
||||
func (el *ErrorLogger) LogWithStackTrace(format string, v ...interface{}) {
|
||||
if el.base != nil {
|
||||
// Get stack trace
|
||||
@@ -58,7 +52,6 @@ func (el *ErrorLogger) LogWithStackTrace(format string, v ...interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// LogFatal logs a fatal error and exits the program
|
||||
func (el *ErrorLogger) LogFatal(format string, v ...interface{}) {
|
||||
if el.base != nil {
|
||||
el.base.Log(LogLevelError, "[FATAL] "+format, v...)
|
||||
@@ -66,13 +59,11 @@ func (el *ErrorLogger) LogFatal(format string, v ...interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// Global error logger instance
|
||||
var (
|
||||
errorLogger *ErrorLogger
|
||||
errorOnce sync.Once
|
||||
)
|
||||
|
||||
// GetErrorLogger returns the global error logger instance
|
||||
func GetErrorLogger() *ErrorLogger {
|
||||
errorOnce.Do(func() {
|
||||
errorLogger = NewErrorLogger()
|
||||
@@ -80,27 +71,22 @@ func GetErrorLogger() *ErrorLogger {
|
||||
return errorLogger
|
||||
}
|
||||
|
||||
// Error logs an error-level message using the global error logger
|
||||
func Error(format string, v ...interface{}) {
|
||||
GetErrorLogger().Log(format, v...)
|
||||
}
|
||||
|
||||
// ErrorWithContext logs an error-level message with context using the global error logger
|
||||
func ErrorWithContext(context string, format string, v ...interface{}) {
|
||||
GetErrorLogger().LogWithContext(context, format, v...)
|
||||
}
|
||||
|
||||
// LogError logs an error object using the global error logger
|
||||
func LogError(err error, message ...string) {
|
||||
GetErrorLogger().LogError(err, message...)
|
||||
}
|
||||
|
||||
// ErrorWithStackTrace logs an error with stack trace using the global error logger
|
||||
func ErrorWithStackTrace(format string, v ...interface{}) {
|
||||
GetErrorLogger().LogWithStackTrace(format, v...)
|
||||
}
|
||||
|
||||
// Fatal logs a fatal error and exits the program using the global error logger
|
||||
func Fatal(format string, v ...interface{}) {
|
||||
GetErrorLogger().LogFatal(format, v...)
|
||||
}
|
||||
|
||||
@@ -5,12 +5,10 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// InfoLogger handles info-level logging
|
||||
type InfoLogger struct {
|
||||
base *BaseLogger
|
||||
}
|
||||
|
||||
// NewInfoLogger creates a new info logger instance
|
||||
func NewInfoLogger() *InfoLogger {
|
||||
base, _ := InitializeBase("info")
|
||||
return &InfoLogger{
|
||||
@@ -18,14 +16,12 @@ func NewInfoLogger() *InfoLogger {
|
||||
}
|
||||
}
|
||||
|
||||
// Log writes an info-level log entry
|
||||
func (il *InfoLogger) Log(format string, v ...interface{}) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithContext writes an info-level log entry with additional context
|
||||
func (il *InfoLogger) LogWithContext(context string, format string, v ...interface{}) {
|
||||
if il.base != nil {
|
||||
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
|
||||
@@ -33,55 +29,47 @@ func (il *InfoLogger) LogWithContext(context string, format string, v ...interfa
|
||||
}
|
||||
}
|
||||
|
||||
// LogStartup logs application startup information
|
||||
func (il *InfoLogger) LogStartup(component string, message string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "STARTUP [%s]: %s", component, message)
|
||||
}
|
||||
}
|
||||
|
||||
// LogShutdown logs application shutdown information
|
||||
func (il *InfoLogger) LogShutdown(component string, message string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "SHUTDOWN [%s]: %s", component, message)
|
||||
}
|
||||
}
|
||||
|
||||
// LogOperation logs general operation information
|
||||
func (il *InfoLogger) LogOperation(operation string, details string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "OPERATION [%s]: %s", operation, details)
|
||||
}
|
||||
}
|
||||
|
||||
// LogStatus logs status changes or updates
|
||||
func (il *InfoLogger) LogStatus(component string, status string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "STATUS [%s]: %s", component, status)
|
||||
}
|
||||
}
|
||||
|
||||
// LogRequest logs incoming requests
|
||||
func (il *InfoLogger) LogRequest(method string, path string, userAgent string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "REQUEST [%s %s] User-Agent: %s", method, path, userAgent)
|
||||
}
|
||||
}
|
||||
|
||||
// LogResponse logs outgoing responses
|
||||
func (il *InfoLogger) LogResponse(method string, path string, statusCode int, duration string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "RESPONSE [%s %s] Status: %d, Duration: %s", method, path, statusCode, duration)
|
||||
}
|
||||
}
|
||||
|
||||
// Global info logger instance
|
||||
var (
|
||||
infoLogger *InfoLogger
|
||||
infoOnce sync.Once
|
||||
)
|
||||
|
||||
// GetInfoLogger returns the global info logger instance
|
||||
func GetInfoLogger() *InfoLogger {
|
||||
infoOnce.Do(func() {
|
||||
infoLogger = NewInfoLogger()
|
||||
@@ -89,42 +77,34 @@ func GetInfoLogger() *InfoLogger {
|
||||
return infoLogger
|
||||
}
|
||||
|
||||
// Info logs an info-level message using the global info logger
|
||||
func Info(format string, v ...interface{}) {
|
||||
GetInfoLogger().Log(format, v...)
|
||||
}
|
||||
|
||||
// InfoWithContext logs an info-level message with context using the global info logger
|
||||
func InfoWithContext(context string, format string, v ...interface{}) {
|
||||
GetInfoLogger().LogWithContext(context, format, v...)
|
||||
}
|
||||
|
||||
// InfoStartup logs application startup information using the global info logger
|
||||
func InfoStartup(component string, message string) {
|
||||
GetInfoLogger().LogStartup(component, message)
|
||||
}
|
||||
|
||||
// InfoShutdown logs application shutdown information using the global info logger
|
||||
func InfoShutdown(component string, message string) {
|
||||
GetInfoLogger().LogShutdown(component, message)
|
||||
}
|
||||
|
||||
// InfoOperation logs general operation information using the global info logger
|
||||
func InfoOperation(operation string, details string) {
|
||||
GetInfoLogger().LogOperation(operation, details)
|
||||
}
|
||||
|
||||
// InfoStatus logs status changes or updates using the global info logger
|
||||
func InfoStatus(component string, status string) {
|
||||
GetInfoLogger().LogStatus(component, status)
|
||||
}
|
||||
|
||||
// InfoRequest logs incoming requests using the global info logger
|
||||
func InfoRequest(method string, path string, userAgent string) {
|
||||
GetInfoLogger().LogRequest(method, path, userAgent)
|
||||
}
|
||||
|
||||
// InfoResponse logs outgoing responses using the global info logger
|
||||
func InfoResponse(method string, path string, statusCode int, duration string) {
|
||||
GetInfoLogger().LogResponse(method, path, statusCode, duration)
|
||||
}
|
||||
|
||||
@@ -6,12 +6,10 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
// Legacy logger for backward compatibility
|
||||
logger *Logger
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
// Logger maintains backward compatibility with existing code
|
||||
type Logger struct {
|
||||
base *BaseLogger
|
||||
errorLogger *ErrorLogger
|
||||
@@ -20,8 +18,6 @@ type Logger struct {
|
||||
debugLogger *DebugLogger
|
||||
}
|
||||
|
||||
// Initialize creates or gets the singleton logger instance
|
||||
// This maintains backward compatibility with existing code
|
||||
func Initialize() (*Logger, error) {
|
||||
var err error
|
||||
once.Do(func() {
|
||||
@@ -31,13 +27,11 @@ func Initialize() (*Logger, error) {
|
||||
}
|
||||
|
||||
func newLogger() (*Logger, error) {
|
||||
// Initialize the base logger
|
||||
baseLogger, err := InitializeBase("log")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create the legacy logger wrapper
|
||||
logger := &Logger{
|
||||
base: baseLogger,
|
||||
errorLogger: GetErrorLogger(),
|
||||
@@ -49,7 +43,6 @@ func newLogger() (*Logger, error) {
|
||||
return logger, nil
|
||||
}
|
||||
|
||||
// Close closes the logger
|
||||
func (l *Logger) Close() error {
|
||||
if l.base != nil {
|
||||
return l.base.Close()
|
||||
@@ -57,7 +50,6 @@ func (l *Logger) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Legacy methods for backward compatibility
|
||||
func (l *Logger) log(level, format string, v ...interface{}) {
|
||||
if l.base != nil {
|
||||
l.base.LogWithCaller(LogLevel(level), 3, format, v...)
|
||||
@@ -94,13 +86,10 @@ func (l *Logger) Panic(format string) {
|
||||
}
|
||||
}
|
||||
|
||||
// Global convenience functions for backward compatibility
|
||||
// These are now implemented in individual logger files to avoid redeclaration
|
||||
func LegacyInfo(format string, v ...interface{}) {
|
||||
if logger != nil {
|
||||
logger.Info(format, v...)
|
||||
} else {
|
||||
// Fallback to direct logger if legacy logger not initialized
|
||||
GetInfoLogger().Log(format, v...)
|
||||
}
|
||||
}
|
||||
@@ -109,7 +98,6 @@ func LegacyError(format string, v ...interface{}) {
|
||||
if logger != nil {
|
||||
logger.Error(format, v...)
|
||||
} else {
|
||||
// Fallback to direct logger if legacy logger not initialized
|
||||
GetErrorLogger().Log(format, v...)
|
||||
}
|
||||
}
|
||||
@@ -118,7 +106,6 @@ func LegacyWarn(format string, v ...interface{}) {
|
||||
if logger != nil {
|
||||
logger.Warn(format, v...)
|
||||
} else {
|
||||
// Fallback to direct logger if legacy logger not initialized
|
||||
GetWarnLogger().Log(format, v...)
|
||||
}
|
||||
}
|
||||
@@ -127,7 +114,6 @@ func LegacyDebug(format string, v ...interface{}) {
|
||||
if logger != nil {
|
||||
logger.Debug(format, v...)
|
||||
} else {
|
||||
// Fallback to direct logger if legacy logger not initialized
|
||||
GetDebugLogger().Log(format, v...)
|
||||
}
|
||||
}
|
||||
@@ -136,55 +122,42 @@ func Panic(format string) {
|
||||
if logger != nil {
|
||||
logger.Panic(format)
|
||||
} else {
|
||||
// Fallback to direct logger if legacy logger not initialized
|
||||
GetErrorLogger().LogFatal(format)
|
||||
}
|
||||
}
|
||||
|
||||
// Enhanced logging convenience functions
|
||||
// These provide direct access to specialized logging functions
|
||||
|
||||
// LogStartup logs application startup information
|
||||
func LogStartup(component string, message string) {
|
||||
GetInfoLogger().LogStartup(component, message)
|
||||
}
|
||||
|
||||
// LogShutdown logs application shutdown information
|
||||
func LogShutdown(component string, message string) {
|
||||
GetInfoLogger().LogShutdown(component, message)
|
||||
}
|
||||
|
||||
// LogOperation logs general operation information
|
||||
func LogOperation(operation string, details string) {
|
||||
GetInfoLogger().LogOperation(operation, details)
|
||||
}
|
||||
|
||||
// LogRequest logs incoming HTTP requests
|
||||
func LogRequest(method string, path string, userAgent string) {
|
||||
GetInfoLogger().LogRequest(method, path, userAgent)
|
||||
}
|
||||
|
||||
// LogResponse logs outgoing HTTP responses
|
||||
func LogResponse(method string, path string, statusCode int, duration string) {
|
||||
GetInfoLogger().LogResponse(method, path, statusCode, duration)
|
||||
}
|
||||
|
||||
// LogSQL logs SQL queries for debugging
|
||||
func LogSQL(query string, args ...interface{}) {
|
||||
GetDebugLogger().LogSQL(query, args...)
|
||||
}
|
||||
|
||||
// LogMemory logs memory usage information
|
||||
func LogMemory() {
|
||||
GetDebugLogger().LogMemory()
|
||||
}
|
||||
|
||||
// LogTiming logs timing information for performance debugging
|
||||
func LogTiming(operation string, duration interface{}) {
|
||||
GetDebugLogger().LogTiming(operation, duration)
|
||||
}
|
||||
|
||||
// GetLegacyLogger returns the legacy logger instance for backward compatibility
|
||||
func GetLegacyLogger() *Logger {
|
||||
if logger == nil {
|
||||
logger, _ = Initialize()
|
||||
@@ -192,21 +165,17 @@ func GetLegacyLogger() *Logger {
|
||||
return logger
|
||||
}
|
||||
|
||||
// InitializeLogging initializes all logging components
|
||||
func InitializeLogging() error {
|
||||
// Initialize legacy logger for backward compatibility
|
||||
_, err := Initialize()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize legacy logger: %v", err)
|
||||
}
|
||||
|
||||
// Pre-initialize all logger types to ensure separate log files
|
||||
GetErrorLogger()
|
||||
GetWarnLogger()
|
||||
GetInfoLogger()
|
||||
GetDebugLogger()
|
||||
|
||||
// Log successful initialization
|
||||
Info("Logging system initialized successfully")
|
||||
|
||||
return nil
|
||||
|
||||
@@ -5,12 +5,10 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// WarnLogger handles warn-level logging
|
||||
type WarnLogger struct {
|
||||
base *BaseLogger
|
||||
}
|
||||
|
||||
// NewWarnLogger creates a new warn logger instance
|
||||
func NewWarnLogger() *WarnLogger {
|
||||
base, _ := InitializeBase("warn")
|
||||
return &WarnLogger{
|
||||
@@ -18,14 +16,12 @@ func NewWarnLogger() *WarnLogger {
|
||||
}
|
||||
}
|
||||
|
||||
// Log writes a warn-level log entry
|
||||
func (wl *WarnLogger) Log(format string, v ...interface{}) {
|
||||
if wl.base != nil {
|
||||
wl.base.Log(LogLevelWarn, format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithContext writes a warn-level log entry with additional context
|
||||
func (wl *WarnLogger) LogWithContext(context string, format string, v ...interface{}) {
|
||||
if wl.base != nil {
|
||||
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
|
||||
@@ -33,7 +29,6 @@ func (wl *WarnLogger) LogWithContext(context string, format string, v ...interfa
|
||||
}
|
||||
}
|
||||
|
||||
// LogDeprecation logs a deprecation warning
|
||||
func (wl *WarnLogger) LogDeprecation(feature string, alternative string) {
|
||||
if wl.base != nil {
|
||||
if alternative != "" {
|
||||
@@ -44,27 +39,23 @@ func (wl *WarnLogger) LogDeprecation(feature string, alternative string) {
|
||||
}
|
||||
}
|
||||
|
||||
// LogConfiguration logs configuration-related warnings
|
||||
func (wl *WarnLogger) LogConfiguration(setting string, message string) {
|
||||
if wl.base != nil {
|
||||
wl.base.Log(LogLevelWarn, "CONFIG WARNING [%s]: %s", setting, message)
|
||||
}
|
||||
}
|
||||
|
||||
// LogPerformance logs performance-related warnings
|
||||
func (wl *WarnLogger) LogPerformance(operation string, threshold string, actual string) {
|
||||
if wl.base != nil {
|
||||
wl.base.Log(LogLevelWarn, "PERFORMANCE WARNING [%s]: exceeded threshold %s, actual: %s", operation, threshold, actual)
|
||||
}
|
||||
}
|
||||
|
||||
// Global warn logger instance
|
||||
var (
|
||||
warnLogger *WarnLogger
|
||||
warnOnce sync.Once
|
||||
)
|
||||
|
||||
// GetWarnLogger returns the global warn logger instance
|
||||
func GetWarnLogger() *WarnLogger {
|
||||
warnOnce.Do(func() {
|
||||
warnLogger = NewWarnLogger()
|
||||
@@ -72,27 +63,22 @@ func GetWarnLogger() *WarnLogger {
|
||||
return warnLogger
|
||||
}
|
||||
|
||||
// Warn logs a warn-level message using the global warn logger
|
||||
func Warn(format string, v ...interface{}) {
|
||||
GetWarnLogger().Log(format, v...)
|
||||
}
|
||||
|
||||
// WarnWithContext logs a warn-level message with context using the global warn logger
|
||||
func WarnWithContext(context string, format string, v ...interface{}) {
|
||||
GetWarnLogger().LogWithContext(context, format, v...)
|
||||
}
|
||||
|
||||
// WarnDeprecation logs a deprecation warning using the global warn logger
|
||||
func WarnDeprecation(feature string, alternative string) {
|
||||
GetWarnLogger().LogDeprecation(feature, alternative)
|
||||
}
|
||||
|
||||
// WarnConfiguration logs configuration-related warnings using the global warn logger
|
||||
func WarnConfiguration(setting string, message string) {
|
||||
GetWarnLogger().LogConfiguration(setting, message)
|
||||
}
|
||||
|
||||
// WarnPerformance logs performance-related warnings using the global warn logger
|
||||
func WarnPerformance(operation string, threshold string, actual string) {
|
||||
GetWarnLogger().LogPerformance(operation, threshold, actual)
|
||||
}
|
||||
|
||||
@@ -6,12 +6,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// IsPortAvailable checks if a port is available for both TCP and UDP
|
||||
func IsPortAvailable(port int) bool {
|
||||
return IsTCPPortAvailable(port) && IsUDPPortAvailable(port)
|
||||
}
|
||||
|
||||
// IsTCPPortAvailable checks if a TCP port is available
|
||||
func IsTCPPortAvailable(port int) bool {
|
||||
addr := fmt.Sprintf(":%d", port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
@@ -22,7 +20,6 @@ func IsTCPPortAvailable(port int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsUDPPortAvailable checks if a UDP port is available
|
||||
func IsUDPPortAvailable(port int) bool {
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
|
||||
if err != nil {
|
||||
@@ -32,7 +29,6 @@ func IsUDPPortAvailable(port int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// FindAvailablePort finds an available port starting from the given port
|
||||
func FindAvailablePort(startPort int) (int, error) {
|
||||
maxPort := 65535
|
||||
for port := startPort; port <= maxPort; port++ {
|
||||
@@ -43,14 +39,12 @@ func FindAvailablePort(startPort int) (int, error) {
|
||||
return 0, fmt.Errorf("no available ports found between %d and %d", startPort, maxPort)
|
||||
}
|
||||
|
||||
// FindAvailablePortRange finds a range of consecutive available ports
|
||||
func FindAvailablePortRange(startPort, count int) ([]int, error) {
|
||||
maxPort := 65535
|
||||
ports := make([]int, 0, count)
|
||||
currentPort := startPort
|
||||
|
||||
for len(ports) < count && currentPort <= maxPort {
|
||||
// Check if we have enough consecutive ports available
|
||||
available := true
|
||||
for i := 0; i < count-len(ports); i++ {
|
||||
if !IsPortAvailable(currentPort + i) {
|
||||
@@ -74,7 +68,6 @@ func FindAvailablePortRange(startPort, count int) ([]int, error) {
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
// WaitForPortAvailable waits for a port to become available with timeout
|
||||
func WaitForPortAvailable(port int, timeout time.Duration) error {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
@@ -84,4 +77,4 @@ func WaitForPortAvailable(port int, timeout time.Duration) error {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
return fmt.Errorf("timeout waiting for port %d to become available", port)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,13 +8,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// MinPasswordLength defines the minimum password length
|
||||
MinPasswordLength = 8
|
||||
// BcryptCost defines the cost factor for bcrypt hashing
|
||||
BcryptCost = 12
|
||||
BcryptCost = 12
|
||||
)
|
||||
|
||||
// HashPassword hashes a plain text password using bcrypt
|
||||
func HashPassword(password string) (string, error) {
|
||||
if len(password) < MinPasswordLength {
|
||||
return "", errors.New("password must be at least 8 characters long")
|
||||
@@ -28,12 +25,10 @@ func HashPassword(password string) (string, error) {
|
||||
return string(hashedBytes), nil
|
||||
}
|
||||
|
||||
// VerifyPassword verifies a plain text password against a hashed password
|
||||
func VerifyPassword(hashedPassword, password string) error {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
}
|
||||
|
||||
// ValidatePasswordStrength validates password complexity requirements
|
||||
func ValidatePasswordStrength(password string) error {
|
||||
if len(password) < MinPasswordLength {
|
||||
return errors.New("password must be at least 8 characters long")
|
||||
|
||||
@@ -17,25 +17,23 @@ import (
|
||||
func Start(di *dig.Container) *fiber.App {
|
||||
app := fiber.New(fiber.Config{
|
||||
EnablePrintRoutes: true,
|
||||
ReadTimeout: 20 * time.Minute, // Increased for long-running Steam operations
|
||||
WriteTimeout: 20 * time.Minute, // Increased for long-running Steam operations
|
||||
IdleTimeout: 25 * time.Minute, // Increased accordingly
|
||||
BodyLimit: 10 * 1024 * 1024, // 10MB
|
||||
ReadTimeout: 20 * time.Minute,
|
||||
WriteTimeout: 20 * time.Minute,
|
||||
IdleTimeout: 25 * time.Minute,
|
||||
BodyLimit: 10 * 1024 * 1024,
|
||||
})
|
||||
|
||||
// Initialize security middleware
|
||||
securityMW := security.NewSecurityMiddleware()
|
||||
|
||||
// Add security middleware stack
|
||||
app.Use(securityMW.SecurityHeaders())
|
||||
app.Use(securityMW.LogSecurityEvents())
|
||||
app.Use(securityMW.TimeoutMiddleware(20 * time.Minute)) // Increased for Steam operations
|
||||
app.Use(securityMW.RequestContextTimeout(20 * time.Minute)) // Increased for Steam operations
|
||||
app.Use(securityMW.RequestSizeLimit(10 * 1024 * 1024)) // 10MB
|
||||
app.Use(securityMW.TimeoutMiddleware(20 * time.Minute))
|
||||
app.Use(securityMW.RequestContextTimeout(20 * time.Minute))
|
||||
app.Use(securityMW.RequestSizeLimit(10 * 1024 * 1024))
|
||||
app.Use(securityMW.ValidateUserAgent())
|
||||
app.Use(securityMW.ValidateContentType("application/json", "application/x-www-form-urlencoded", "multipart/form-data"))
|
||||
app.Use(securityMW.InputSanitization())
|
||||
app.Use(securityMW.RateLimit(100, 1*time.Minute)) // 100 requests per minute global
|
||||
app.Use(securityMW.RateLimit(100, 1*time.Minute))
|
||||
|
||||
app.Use(helmet.New())
|
||||
|
||||
@@ -62,7 +60,7 @@ func Start(di *dig.Container) *fiber.App {
|
||||
|
||||
port := os.Getenv("PORT")
|
||||
if port == "" {
|
||||
port = "3000" // Default port
|
||||
port = "3000"
|
||||
}
|
||||
|
||||
logging.Info("Starting server on port %s", port)
|
||||
|
||||
@@ -7,11 +7,11 @@ import (
|
||||
)
|
||||
|
||||
type LogTailer struct {
|
||||
filePath string
|
||||
handleLine func(string)
|
||||
stopChan chan struct{}
|
||||
isRunning bool
|
||||
tracker *PositionTracker
|
||||
filePath string
|
||||
handleLine func(string)
|
||||
stopChan chan struct{}
|
||||
isRunning bool
|
||||
tracker *PositionTracker
|
||||
}
|
||||
|
||||
func NewLogTailer(filePath string, handleLine func(string)) *LogTailer {
|
||||
@@ -30,10 +30,9 @@ func (t *LogTailer) Start() {
|
||||
t.isRunning = true
|
||||
|
||||
go func() {
|
||||
// Load last position from tracker
|
||||
pos, err := t.tracker.LoadPosition()
|
||||
if err != nil {
|
||||
pos = &LogPosition{} // Start from beginning if error
|
||||
pos = &LogPosition{}
|
||||
}
|
||||
lastSize := pos.LastPosition
|
||||
|
||||
@@ -43,7 +42,6 @@ func (t *LogTailer) Start() {
|
||||
t.isRunning = false
|
||||
return
|
||||
default:
|
||||
// Try to open and read the file
|
||||
if file, err := os.Open(t.filePath); err == nil {
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
@@ -52,12 +50,10 @@ func (t *LogTailer) Start() {
|
||||
continue
|
||||
}
|
||||
|
||||
// If file was truncated, start from beginning
|
||||
if stat.Size() < lastSize {
|
||||
lastSize = 0
|
||||
}
|
||||
|
||||
// Seek to last read position
|
||||
if lastSize > 0 {
|
||||
file.Seek(lastSize, 0)
|
||||
}
|
||||
@@ -66,9 +62,8 @@ func (t *LogTailer) Start() {
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
t.handleLine(line)
|
||||
lastSize, _ = file.Seek(0, 1) // Get current position
|
||||
|
||||
// Save position periodically
|
||||
lastSize, _ = file.Seek(0, 1)
|
||||
|
||||
t.tracker.SavePosition(&LogPosition{
|
||||
LastPosition: lastSize,
|
||||
LastRead: line,
|
||||
@@ -78,7 +73,6 @@ func (t *LogTailer) Start() {
|
||||
file.Close()
|
||||
}
|
||||
|
||||
// Wait before next attempt
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
@@ -90,4 +84,4 @@ func (t *LogTailer) Stop() {
|
||||
return
|
||||
}
|
||||
close(t.stopChan)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
)
|
||||
|
||||
type LogPosition struct {
|
||||
LastPosition int64 `json:"last_position"`
|
||||
LastRead string `json:"last_read"`
|
||||
LastPosition int64 `json:"last_position"`
|
||||
LastRead string `json:"last_read"`
|
||||
}
|
||||
|
||||
type PositionTracker struct {
|
||||
@@ -16,11 +16,10 @@ type PositionTracker struct {
|
||||
}
|
||||
|
||||
func NewPositionTracker(logPath string) *PositionTracker {
|
||||
// Create position file in same directory as log file
|
||||
dir := filepath.Dir(logPath)
|
||||
base := filepath.Base(logPath)
|
||||
positionFile := filepath.Join(dir, "."+base+".position")
|
||||
|
||||
|
||||
return &PositionTracker{
|
||||
positionFile: positionFile,
|
||||
}
|
||||
@@ -30,7 +29,6 @@ func (t *PositionTracker) LoadPosition() (*LogPosition, error) {
|
||||
data, err := os.ReadFile(t.positionFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// Return empty position if file doesn't exist
|
||||
return &LogPosition{}, nil
|
||||
}
|
||||
return nil, err
|
||||
@@ -51,4 +49,4 @@ func (t *PositionTracker) SavePosition(pos *LogPosition) error {
|
||||
}
|
||||
|
||||
return os.WriteFile(t.positionFile, data, 0644)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ func TailLogFile(path string, callback func(string)) {
|
||||
file, _ := os.Open(path)
|
||||
defer file.Close()
|
||||
|
||||
file.Seek(0, os.SEEK_END) // Start at end of file
|
||||
file.Seek(0, os.SEEK_END)
|
||||
reader := bufio.NewReader(file)
|
||||
|
||||
for {
|
||||
@@ -88,7 +88,7 @@ func TailLogFile(path string, callback func(string)) {
|
||||
if err == nil {
|
||||
callback(line)
|
||||
} else {
|
||||
time.Sleep(500 * time.Millisecond) // wait for new data
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
169
local/utl/websocket/websocket.go
Normal file
169
local/utl/websocket/websocket.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/model"
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type WebSocketConnection struct {
|
||||
conn *websocket.Conn
|
||||
serverID *uuid.UUID
|
||||
userID *uuid.UUID
|
||||
}
|
||||
|
||||
type WebSocketService struct {
|
||||
connections sync.Map
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewWebSocketService() *WebSocketService {
|
||||
return &WebSocketService{}
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, userID *uuid.UUID) {
|
||||
wsConn := &WebSocketConnection{
|
||||
conn: conn,
|
||||
userID: userID,
|
||||
}
|
||||
ws.connections.Store(connID, wsConn)
|
||||
logging.Info("WebSocket connection added: %s for user: %v", connID, userID)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) RemoveConnection(connID string) {
|
||||
if conn, exists := ws.connections.LoadAndDelete(connID); exists {
|
||||
if wsConn, ok := conn.(*WebSocketConnection); ok {
|
||||
wsConn.conn.Close()
|
||||
}
|
||||
}
|
||||
logging.Info("WebSocket connection removed: %s", connID)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) {
|
||||
if conn, exists := ws.connections.Load(connID); exists {
|
||||
if wsConn, ok := conn.(*WebSocketConnection); ok {
|
||||
wsConn.serverID = &serverID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerCreationStep, status model.StepStatus, message string, errorMsg string) {
|
||||
stepMsg := model.StepMessage{
|
||||
Step: step,
|
||||
Status: status,
|
||||
Message: message,
|
||||
Error: errorMsg,
|
||||
}
|
||||
|
||||
wsMsg := model.WebSocketMessage{
|
||||
Type: model.MessageTypeStep,
|
||||
ServerID: &serverID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Data: stepMsg,
|
||||
}
|
||||
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output string, isError bool) {
|
||||
steamMsg := model.SteamOutputMessage{
|
||||
Output: output,
|
||||
IsError: isError,
|
||||
}
|
||||
|
||||
wsMsg := model.WebSocketMessage{
|
||||
Type: model.MessageTypeSteamOutput,
|
||||
ServerID: &serverID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Data: steamMsg,
|
||||
}
|
||||
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, details string) {
|
||||
errorMsg := model.ErrorMessage{
|
||||
Error: error,
|
||||
Details: details,
|
||||
}
|
||||
|
||||
wsMsg := model.WebSocketMessage{
|
||||
Type: model.MessageTypeError,
|
||||
ServerID: &serverID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Data: errorMsg,
|
||||
}
|
||||
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool, message string) {
|
||||
completeMsg := model.CompleteMessage{
|
||||
ServerID: serverID,
|
||||
Success: success,
|
||||
Message: message,
|
||||
}
|
||||
|
||||
wsMsg := model.WebSocketMessage{
|
||||
Type: model.MessageTypeComplete,
|
||||
ServerID: &serverID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Data: completeMsg,
|
||||
}
|
||||
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model.WebSocketMessage) {
|
||||
data, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
logging.Error("Failed to marshal WebSocket message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ws.connections.Range(func(key, value interface{}) bool {
|
||||
if wsConn, ok := value.(*WebSocketConnection); ok {
|
||||
if wsConn.serverID != nil && *wsConn.serverID == serverID {
|
||||
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
|
||||
ws.RemoveConnection(key.(string))
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebSocketMessage) {
|
||||
data, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
logging.Error("Failed to marshal WebSocket message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ws.connections.Range(func(key, value interface{}) bool {
|
||||
if wsConn, ok := value.(*WebSocketConnection); ok {
|
||||
if wsConn.userID != nil && *wsConn.userID == userID {
|
||||
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
|
||||
ws.RemoveConnection(key.(string))
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) GetActiveConnections() int {
|
||||
count := 0
|
||||
ws.connections.Range(func(key, value interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
@@ -1,4 +1,3 @@
|
||||
// Package swagger Code generated by swaggo/swag. DO NOT EDIT
|
||||
package swagger
|
||||
|
||||
import "github.com/swaggo/swag"
|
||||
@@ -2239,7 +2238,6 @@ const docTemplate = `{
|
||||
}
|
||||
}`
|
||||
|
||||
// SwaggerInfo holds exported Swagger Info so clients can modify it
|
||||
var SwaggerInfo = &swag.Spec{
|
||||
Version: "1.0",
|
||||
Host: "acc-api.jurmanovic.com",
|
||||
|
||||
@@ -10,24 +10,19 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// GenerateTestToken creates a JWT token for testing purposes
|
||||
func GenerateTestToken() (string, error) {
|
||||
// Create test user
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "test_user",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Use the environment JWT_SECRET for consistency with middleware
|
||||
testSecret := os.Getenv("JWT_SECRET")
|
||||
if testSecret == "" {
|
||||
// Fallback to a test secret if env var is not set
|
||||
testSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
|
||||
}
|
||||
jwtHandler := jwt.NewJWTHandler(testSecret)
|
||||
|
||||
// Generate JWT token
|
||||
token, err := jwtHandler.GenerateToken(user.ID.String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate test token: %w", err)
|
||||
@@ -36,8 +31,6 @@ func GenerateTestToken() (string, error) {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// MustGenerateTestToken generates a test token and panics if it fails
|
||||
// This is useful for test setup where failing to generate a token is a fatal error
|
||||
func MustGenerateTestToken() string {
|
||||
token, err := GenerateTestToken()
|
||||
if err != nil {
|
||||
@@ -46,24 +39,19 @@ func MustGenerateTestToken() string {
|
||||
return token
|
||||
}
|
||||
|
||||
// GenerateTestTokenWithExpiry creates a JWT token with a specific expiry time
|
||||
func GenerateTestTokenWithExpiry(expiryTime time.Time) (string, error) {
|
||||
// Use the environment JWT_SECRET for consistency with middleware
|
||||
testSecret := os.Getenv("JWT_SECRET")
|
||||
if testSecret == "" {
|
||||
// Fallback to a test secret if env var is not set
|
||||
testSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
|
||||
}
|
||||
jwtHandler := jwt.NewJWTHandler(testSecret)
|
||||
|
||||
// Create test user
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "test_user",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Generate JWT token with custom expiry
|
||||
token, err := jwtHandler.GenerateTokenWithExpiry(user, expiryTime)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate test token with expiry: %w", err)
|
||||
@@ -72,8 +60,6 @@ func GenerateTestTokenWithExpiry(expiryTime time.Time) (string, error) {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// AddAuthHeader adds a test auth token to the request headers
|
||||
// This is a convenience method for tests that need to authenticate requests
|
||||
func AddAuthHeader(headers map[string]string) (map[string]string, error) {
|
||||
token, err := GenerateTestToken()
|
||||
if err != nil {
|
||||
@@ -88,7 +74,6 @@ func AddAuthHeader(headers map[string]string) (map[string]string, error) {
|
||||
return headers, nil
|
||||
}
|
||||
|
||||
// MustAddAuthHeader adds a test auth token to the request headers and panics if it fails
|
||||
func MustAddAuthHeader(headers map[string]string) map[string]string {
|
||||
result, err := AddAuthHeader(headers)
|
||||
if err != nil {
|
||||
|
||||
@@ -8,26 +8,20 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// MockAuthMiddleware provides a test implementation of AuthMiddleware
|
||||
// that can be used as a drop-in replacement for the real AuthMiddleware
|
||||
type MockAuthMiddleware struct{}
|
||||
|
||||
// NewMockAuthMiddleware creates a new MockAuthMiddleware
|
||||
func NewMockAuthMiddleware() *MockAuthMiddleware {
|
||||
return &MockAuthMiddleware{}
|
||||
}
|
||||
|
||||
// Authenticate is a middleware that allows all requests without authentication for testing
|
||||
func (m *MockAuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
|
||||
// Set a mock user ID in context
|
||||
mockUserID := uuid.New().String()
|
||||
ctx.Locals("userID", mockUserID)
|
||||
|
||||
// Set mock user info
|
||||
mockUserInfo := &middleware.CachedUserInfo{
|
||||
UserID: mockUserID,
|
||||
Username: "test_user",
|
||||
RoleName: "Admin", // Admin role to bypass permission checks
|
||||
RoleName: "Admin",
|
||||
Permissions: map[string]bool{"*": true},
|
||||
CachedAt: time.Now(),
|
||||
}
|
||||
@@ -38,21 +32,18 @@ func (m *MockAuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
// HasPermission is a middleware that allows all permission checks to pass for testing
|
||||
func (m *MockAuthMiddleware) HasPermission(requiredPermission string) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
return ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// AuthRateLimit is a test implementation that allows all requests
|
||||
func (m *MockAuthMiddleware) AuthRateLimit() fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
return ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireHTTPS is a test implementation that allows all HTTP requests
|
||||
func (m *MockAuthMiddleware) RequireHTTPS() fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
return ctx.Next()
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// MockConfigRepository provides a mock implementation of ConfigRepository
|
||||
type MockConfigRepository struct {
|
||||
configs map[string]*model.Config
|
||||
shouldFailGet bool
|
||||
@@ -21,7 +20,6 @@ func NewMockConfigRepository() *MockConfigRepository {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig mocks the UpdateConfig method
|
||||
func (m *MockConfigRepository) UpdateConfig(ctx context.Context, config *model.Config) *model.Config {
|
||||
if m.shouldFailUpdate {
|
||||
return nil
|
||||
@@ -36,18 +34,15 @@ func (m *MockConfigRepository) UpdateConfig(ctx context.Context, config *model.C
|
||||
return config
|
||||
}
|
||||
|
||||
// SetShouldFailUpdate configures the mock to fail on UpdateConfig calls
|
||||
func (m *MockConfigRepository) SetShouldFailUpdate(shouldFail bool) {
|
||||
m.shouldFailUpdate = shouldFail
|
||||
}
|
||||
|
||||
// GetConfig retrieves a config by server ID and config file
|
||||
func (m *MockConfigRepository) GetConfig(serverID uuid.UUID, configFile string) *model.Config {
|
||||
key := serverID.String() + "_" + configFile
|
||||
return m.configs[key]
|
||||
}
|
||||
|
||||
// MockServerRepository provides a mock implementation of ServerRepository
|
||||
type MockServerRepository struct {
|
||||
servers map[uuid.UUID]*model.Server
|
||||
shouldFailGet bool
|
||||
@@ -59,7 +54,6 @@ func NewMockServerRepository() *MockServerRepository {
|
||||
}
|
||||
}
|
||||
|
||||
// GetByID mocks the GetByID method
|
||||
func (m *MockServerRepository) GetByID(ctx context.Context, id interface{}) (*model.Server, error) {
|
||||
if m.shouldFailGet {
|
||||
return nil, errors.New("server not found")
|
||||
@@ -88,17 +82,14 @@ func (m *MockServerRepository) GetByID(ctx context.Context, id interface{}) (*mo
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// AddServer adds a server to the mock repository
|
||||
func (m *MockServerRepository) AddServer(server *model.Server) {
|
||||
m.servers[server.ID] = server
|
||||
}
|
||||
|
||||
// SetShouldFailGet configures the mock to fail on GetByID calls
|
||||
func (m *MockServerRepository) SetShouldFailGet(shouldFail bool) {
|
||||
m.shouldFailGet = shouldFail
|
||||
}
|
||||
|
||||
// MockServerService provides a mock implementation of ServerService
|
||||
type MockServerService struct {
|
||||
startRuntimeCalled bool
|
||||
startRuntimeServer *model.Server
|
||||
@@ -108,23 +99,19 @@ func NewMockServerService() *MockServerService {
|
||||
return &MockServerService{}
|
||||
}
|
||||
|
||||
// StartAccServerRuntime mocks the StartAccServerRuntime method
|
||||
func (m *MockServerService) StartAccServerRuntime(server *model.Server) {
|
||||
m.startRuntimeCalled = true
|
||||
m.startRuntimeServer = server
|
||||
}
|
||||
|
||||
// WasStartRuntimeCalled returns whether StartAccServerRuntime was called
|
||||
func (m *MockServerService) WasStartRuntimeCalled() bool {
|
||||
return m.startRuntimeCalled
|
||||
}
|
||||
|
||||
// GetStartRuntimeServer returns the server passed to StartAccServerRuntime
|
||||
func (m *MockServerService) GetStartRuntimeServer() *model.Server {
|
||||
return m.startRuntimeServer
|
||||
}
|
||||
|
||||
// Reset resets the mock state
|
||||
func (m *MockServerService) Reset() {
|
||||
m.startRuntimeCalled = false
|
||||
m.startRuntimeServer = nil
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// MockStateHistoryRepository provides a mock implementation of StateHistoryRepository
|
||||
type MockStateHistoryRepository struct {
|
||||
stateHistories []model.StateHistory
|
||||
shouldFailGet bool
|
||||
@@ -21,7 +20,6 @@ func NewMockStateHistoryRepository() *MockStateHistoryRepository {
|
||||
}
|
||||
}
|
||||
|
||||
// GetAll mocks the GetAll method
|
||||
func (m *MockStateHistoryRepository) GetAll(ctx context.Context, filter *model.StateHistoryFilter) (*[]model.StateHistory, error) {
|
||||
if m.shouldFailGet {
|
||||
return nil, errors.New("failed to get state history")
|
||||
@@ -37,13 +35,11 @@ func (m *MockStateHistoryRepository) GetAll(ctx context.Context, filter *model.S
|
||||
return &filtered, nil
|
||||
}
|
||||
|
||||
// Insert mocks the Insert method
|
||||
func (m *MockStateHistoryRepository) Insert(ctx context.Context, stateHistory *model.StateHistory) error {
|
||||
if m.shouldFailInsert {
|
||||
return errors.New("failed to insert state history")
|
||||
}
|
||||
|
||||
// Simulate BeforeCreate hook
|
||||
if stateHistory.ID == uuid.Nil {
|
||||
stateHistory.ID = uuid.New()
|
||||
}
|
||||
@@ -55,7 +51,6 @@ func (m *MockStateHistoryRepository) Insert(ctx context.Context, stateHistory *m
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLastSessionID mocks the GetLastSessionID method
|
||||
func (m *MockStateHistoryRepository) GetLastSessionID(ctx context.Context, serverID uuid.UUID) (uuid.UUID, error) {
|
||||
for i := len(m.stateHistories) - 1; i >= 0; i-- {
|
||||
if m.stateHistories[i].ServerID == serverID {
|
||||
@@ -65,7 +60,6 @@ func (m *MockStateHistoryRepository) GetLastSessionID(ctx context.Context, serve
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
|
||||
// Helper methods for filtering
|
||||
func (m *MockStateHistoryRepository) matchesFilter(sh model.StateHistory, filter *model.StateHistoryFilter) bool {
|
||||
if filter == nil {
|
||||
return true
|
||||
@@ -93,7 +87,6 @@ func (m *MockStateHistoryRepository) matchesFilter(sh model.StateHistory, filter
|
||||
return true
|
||||
}
|
||||
|
||||
// Helper methods for testing configuration
|
||||
func (m *MockStateHistoryRepository) SetShouldFailGet(shouldFail bool) {
|
||||
m.shouldFailGet = shouldFail
|
||||
}
|
||||
@@ -102,7 +95,6 @@ func (m *MockStateHistoryRepository) SetShouldFailInsert(shouldFail bool) {
|
||||
m.shouldFailInsert = shouldFail
|
||||
}
|
||||
|
||||
// AddStateHistory adds a state history entry to the mock repository
|
||||
func (m *MockStateHistoryRepository) AddStateHistory(stateHistory model.StateHistory) {
|
||||
if stateHistory.ID == uuid.Nil {
|
||||
stateHistory.ID = uuid.New()
|
||||
@@ -113,22 +105,18 @@ func (m *MockStateHistoryRepository) AddStateHistory(stateHistory model.StateHis
|
||||
m.stateHistories = append(m.stateHistories, stateHistory)
|
||||
}
|
||||
|
||||
// GetCount returns the number of state history entries
|
||||
func (m *MockStateHistoryRepository) GetCount() int {
|
||||
return len(m.stateHistories)
|
||||
}
|
||||
|
||||
// Clear removes all state history entries
|
||||
func (m *MockStateHistoryRepository) Clear() {
|
||||
m.stateHistories = make([]model.StateHistory, 0)
|
||||
}
|
||||
|
||||
// GetSummaryStats calculates peak players, total sessions, and average players for mock data
|
||||
func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter *model.StateHistoryFilter) (model.StateHistoryStats, error) {
|
||||
var stats model.StateHistoryStats
|
||||
var filteredEntries []model.StateHistory
|
||||
|
||||
// Filter entries
|
||||
for _, entry := range m.stateHistories {
|
||||
if m.matchesFilter(entry, filter) {
|
||||
filteredEntries = append(filteredEntries, entry)
|
||||
@@ -139,7 +127,6 @@ func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// Calculate statistics
|
||||
sessionMap := make(map[string]bool)
|
||||
totalPlayers := 0
|
||||
|
||||
@@ -159,11 +146,9 @@ func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetTotalPlaytime calculates total playtime in minutes for mock data
|
||||
func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *model.StateHistoryFilter) (int, error) {
|
||||
var filteredEntries []model.StateHistory
|
||||
|
||||
// Filter entries
|
||||
for _, entry := range m.stateHistories {
|
||||
if m.matchesFilter(entry, filter) {
|
||||
filteredEntries = append(filteredEntries, entry)
|
||||
@@ -174,7 +159,6 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Group by session and calculate durations
|
||||
sessionMap := make(map[string][]model.StateHistory)
|
||||
for _, entry := range filteredEntries {
|
||||
sessionID := entry.SessionID.String()
|
||||
@@ -184,7 +168,6 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte
|
||||
totalMinutes := 0
|
||||
for _, sessionEntries := range sessionMap {
|
||||
if len(sessionEntries) > 1 {
|
||||
// Sort by date (simple approach for mock)
|
||||
minTime := sessionEntries[0].DateCreated
|
||||
maxTime := sessionEntries[0].DateCreated
|
||||
hasPlayers := false
|
||||
@@ -211,26 +194,22 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte
|
||||
return totalMinutes, nil
|
||||
}
|
||||
|
||||
// GetPlayerCountOverTime returns downsampled player count data for mock
|
||||
func (m *MockStateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, filter *model.StateHistoryFilter) ([]model.PlayerCountPoint, error) {
|
||||
var points []model.PlayerCountPoint
|
||||
var filteredEntries []model.StateHistory
|
||||
|
||||
// Filter entries
|
||||
for _, entry := range m.stateHistories {
|
||||
if m.matchesFilter(entry, filter) {
|
||||
filteredEntries = append(filteredEntries, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Group by hour (simple mock implementation)
|
||||
hourMap := make(map[string][]int)
|
||||
for _, entry := range filteredEntries {
|
||||
hourKey := entry.DateCreated.Format("2006-01-02 15")
|
||||
hourMap[hourKey] = append(hourMap[hourKey], entry.PlayerCount)
|
||||
}
|
||||
|
||||
// Calculate averages per hour
|
||||
for hourKey, counts := range hourMap {
|
||||
total := 0
|
||||
for _, count := range counts {
|
||||
@@ -247,20 +226,17 @@ func (m *MockStateHistoryRepository) GetPlayerCountOverTime(ctx context.Context,
|
||||
return points, nil
|
||||
}
|
||||
|
||||
// GetSessionTypes counts sessions by type for mock
|
||||
func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter *model.StateHistoryFilter) ([]model.SessionCount, error) {
|
||||
var sessionTypes []model.SessionCount
|
||||
var filteredEntries []model.StateHistory
|
||||
|
||||
// Filter entries
|
||||
for _, entry := range m.stateHistories {
|
||||
if m.matchesFilter(entry, filter) {
|
||||
filteredEntries = append(filteredEntries, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Group by session type
|
||||
sessionMap := make(map[model.TrackSession]map[string]bool) // session -> sessionID -> bool
|
||||
sessionMap := make(map[model.TrackSession]map[string]bool)
|
||||
for _, entry := range filteredEntries {
|
||||
if sessionMap[entry.Session] == nil {
|
||||
sessionMap[entry.Session] = make(map[string]bool)
|
||||
@@ -268,7 +244,6 @@ func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter
|
||||
sessionMap[entry.Session][entry.SessionID.String()] = true
|
||||
}
|
||||
|
||||
// Count unique sessions per type
|
||||
for sessionType, sessions := range sessionMap {
|
||||
sessionTypes = append(sessionTypes, model.SessionCount{
|
||||
Name: sessionType,
|
||||
@@ -279,20 +254,17 @@ func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter
|
||||
return sessionTypes, nil
|
||||
}
|
||||
|
||||
// GetDailyActivity counts sessions per day for mock
|
||||
func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filter *model.StateHistoryFilter) ([]model.DailyActivity, error) {
|
||||
var dailyActivity []model.DailyActivity
|
||||
var filteredEntries []model.StateHistory
|
||||
|
||||
// Filter entries
|
||||
for _, entry := range m.stateHistories {
|
||||
if m.matchesFilter(entry, filter) {
|
||||
filteredEntries = append(filteredEntries, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Group by day
|
||||
dayMap := make(map[string]map[string]bool) // date -> sessionID -> bool
|
||||
dayMap := make(map[string]map[string]bool)
|
||||
for _, entry := range filteredEntries {
|
||||
dateKey := entry.DateCreated.Format("2006-01-02")
|
||||
if dayMap[dateKey] == nil {
|
||||
@@ -301,7 +273,6 @@ func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filte
|
||||
dayMap[dateKey][entry.SessionID.String()] = true
|
||||
}
|
||||
|
||||
// Count unique sessions per day
|
||||
for date, sessions := range dayMap {
|
||||
dailyActivity = append(dailyActivity, model.DailyActivity{
|
||||
Date: date,
|
||||
@@ -312,26 +283,22 @@ func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filte
|
||||
return dailyActivity, nil
|
||||
}
|
||||
|
||||
// GetRecentSessions retrieves recent sessions for mock
|
||||
func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filter *model.StateHistoryFilter) ([]model.RecentSession, error) {
|
||||
var recentSessions []model.RecentSession
|
||||
var filteredEntries []model.StateHistory
|
||||
|
||||
// Filter entries
|
||||
for _, entry := range m.stateHistories {
|
||||
if m.matchesFilter(entry, filter) {
|
||||
filteredEntries = append(filteredEntries, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Group by session
|
||||
sessionMap := make(map[string][]model.StateHistory)
|
||||
for _, entry := range filteredEntries {
|
||||
sessionID := entry.SessionID.String()
|
||||
sessionMap[sessionID] = append(sessionMap[sessionID], entry)
|
||||
}
|
||||
|
||||
// Create recent sessions (limit to 10)
|
||||
count := 0
|
||||
for _, entries := range sessionMap {
|
||||
if count >= 10 {
|
||||
@@ -339,7 +306,6 @@ func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filt
|
||||
}
|
||||
|
||||
if len(entries) > 0 {
|
||||
// Find min/max dates and max players
|
||||
minDate := entries[0].DateCreated
|
||||
maxDate := entries[0].DateCreated
|
||||
maxPlayers := 0
|
||||
@@ -356,7 +322,6 @@ func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filt
|
||||
}
|
||||
}
|
||||
|
||||
// Only include sessions with players
|
||||
if maxPlayers > 0 {
|
||||
duration := int(maxDate.Sub(minDate).Minutes())
|
||||
recentSessions = append(recentSessions, model.RecentSession{
|
||||
|
||||
@@ -24,14 +24,12 @@ import (
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// TestHelper provides utilities for testing
|
||||
type TestHelper struct {
|
||||
DB *gorm.DB
|
||||
TempDir string
|
||||
TestData *TestData
|
||||
}
|
||||
|
||||
// TestData contains common test data structures
|
||||
type TestData struct {
|
||||
ServerID uuid.UUID
|
||||
Server *model.Server
|
||||
@@ -39,37 +37,29 @@ type TestData struct {
|
||||
SampleConfig *model.Configuration
|
||||
}
|
||||
|
||||
// SetTestEnv sets the required environment variables for tests
|
||||
func SetTestEnv() {
|
||||
// Set required environment variables for testing
|
||||
os.Setenv("APP_SECRET", "test-secret-key-for-testing-123456")
|
||||
os.Setenv("APP_SECRET_CODE", "test-code-for-testing-123456789012")
|
||||
os.Setenv("ENCRYPTION_KEY", "12345678901234567890123456789012")
|
||||
os.Setenv("JWT_SECRET", "test-jwt-secret-key-for-testing-123456789012345678901234567890")
|
||||
os.Setenv("ACCESS_KEY", "test-access-key-for-testing")
|
||||
// Set test-specific environment variables
|
||||
os.Setenv("TESTING_ENV", "true") // Used to bypass
|
||||
os.Setenv("TESTING_ENV", "true")
|
||||
|
||||
configs.Init()
|
||||
}
|
||||
|
||||
// NewTestHelper creates a new test helper with in-memory database
|
||||
func NewTestHelper(t *testing.T) *TestHelper {
|
||||
// Set required environment variables
|
||||
SetTestEnv()
|
||||
|
||||
// Create temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create in-memory SQLite database for testing
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent), // Suppress SQL logs in tests
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to test database: %v", err)
|
||||
}
|
||||
|
||||
// Auto-migrate the schema
|
||||
err = db.AutoMigrate(
|
||||
&model.Server{},
|
||||
&model.Config{},
|
||||
@@ -79,7 +69,6 @@ func NewTestHelper(t *testing.T) *TestHelper {
|
||||
&model.StateHistory{},
|
||||
)
|
||||
|
||||
// Explicitly ensure tables exist with correct structure
|
||||
if !db.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err = db.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -90,7 +79,6 @@ func NewTestHelper(t *testing.T) *TestHelper {
|
||||
t.Fatalf("Failed to migrate test database: %v", err)
|
||||
}
|
||||
|
||||
// Create test data
|
||||
testData := createTestData(t, tempDir)
|
||||
|
||||
return &TestHelper{
|
||||
@@ -100,11 +88,9 @@ func NewTestHelper(t *testing.T) *TestHelper {
|
||||
}
|
||||
}
|
||||
|
||||
// createTestData creates common test data structures
|
||||
func createTestData(t *testing.T, tempDir string) *TestData {
|
||||
serverID := uuid.New()
|
||||
|
||||
// Create sample server
|
||||
server := &model.Server{
|
||||
ID: serverID,
|
||||
Name: "Test Server",
|
||||
@@ -115,13 +101,11 @@ func createTestData(t *testing.T, tempDir string) *TestData {
|
||||
FromSteamCMD: false,
|
||||
}
|
||||
|
||||
// Create server directory
|
||||
serverConfigDir := filepath.Join(tempDir, "server", "cfg")
|
||||
if err := os.MkdirAll(serverConfigDir, 0755); err != nil {
|
||||
t.Fatalf("Failed to create server config directory: %v", err)
|
||||
}
|
||||
|
||||
// Sample configuration files content
|
||||
configFiles := map[string]string{
|
||||
"configuration.json": `{
|
||||
"udpPort": "9231",
|
||||
@@ -212,7 +196,6 @@ func createTestData(t *testing.T, tempDir string) *TestData {
|
||||
}`,
|
||||
}
|
||||
|
||||
// Sample configuration struct
|
||||
sampleConfig := &model.Configuration{
|
||||
UdpPort: model.IntString(9231),
|
||||
TcpPort: model.IntString(9232),
|
||||
@@ -230,14 +213,12 @@ func createTestData(t *testing.T, tempDir string) *TestData {
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTestConfigFiles creates actual config files in the test directory
|
||||
func (th *TestHelper) CreateTestConfigFiles() error {
|
||||
serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg")
|
||||
|
||||
for filename, content := range th.TestData.ConfigFiles {
|
||||
filePath := filepath.Join(serverConfigDir, filename)
|
||||
|
||||
// Encode content to UTF-16 LE BOM format as expected by the application
|
||||
utf16Content, err := EncodeUTF16LEBOM([]byte(content))
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -251,7 +232,6 @@ func (th *TestHelper) CreateTestConfigFiles() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateMalformedConfigFile creates a config file with invalid JSON
|
||||
func (th *TestHelper) CreateMalformedConfigFile(filename string) error {
|
||||
serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg")
|
||||
filePath := filepath.Join(serverConfigDir, filename)
|
||||
@@ -259,67 +239,51 @@ func (th *TestHelper) CreateMalformedConfigFile(filename string) error {
|
||||
malformedJSON := `{
|
||||
"udpPort": "9231",
|
||||
"tcpPort": "9232"
|
||||
"maxConnections": "30" // Missing comma - invalid JSON
|
||||
"maxConnections": "30"
|
||||
}`
|
||||
|
||||
return os.WriteFile(filePath, []byte(malformedJSON), 0644)
|
||||
}
|
||||
|
||||
// RemoveConfigFile removes a config file to simulate missing file scenarios
|
||||
func (th *TestHelper) RemoveConfigFile(filename string) error {
|
||||
serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg")
|
||||
filePath := filepath.Join(serverConfigDir, filename)
|
||||
return os.Remove(filePath)
|
||||
}
|
||||
|
||||
// InsertTestServer inserts the test server into the database
|
||||
func (th *TestHelper) InsertTestServer() error {
|
||||
return th.DB.Create(th.TestData.Server).Error
|
||||
}
|
||||
|
||||
// CreateContext creates a test context
|
||||
func (th *TestHelper) CreateContext() context.Context {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// CreateFiberCtx creates a fiber.Ctx for testing
|
||||
func (th *TestHelper) CreateFiberCtx() *fiber.Ctx {
|
||||
// Create app and request for fiber context
|
||||
app := fiber.New()
|
||||
// Create a dummy request that doesn't depend on external http objects
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
// Create the fiber context from real request/response
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
// Store the original request for release later
|
||||
ctx.Locals("original-request", req)
|
||||
// Return the context which can be safely used in tests
|
||||
return ctx
|
||||
}
|
||||
|
||||
// ReleaseFiberCtx properly releases a fiber context created with CreateFiberCtx
|
||||
func (th *TestHelper) ReleaseFiberCtx(app *fiber.App, ctx *fiber.Ctx) {
|
||||
if app != nil && ctx != nil {
|
||||
app.ReleaseCtx(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup performs cleanup operations after tests
|
||||
func (th *TestHelper) Cleanup() {
|
||||
// Close database connection
|
||||
if sqlDB, err := th.DB.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
|
||||
// Temporary directory is automatically cleaned up by t.TempDir()
|
||||
}
|
||||
|
||||
// LoadTestEnvFile loads environment variables from a .env file for testing
|
||||
func LoadTestEnvFile() error {
|
||||
// Try to load from .env file
|
||||
return godotenv.Load()
|
||||
}
|
||||
|
||||
// AssertNoError is a helper function to check for errors in tests
|
||||
func AssertNoError(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
@@ -327,7 +291,6 @@ func AssertNoError(t *testing.T, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// AssertError is a helper function to check for expected errors
|
||||
func AssertError(t *testing.T, err error, expectedMsg string) {
|
||||
t.Helper()
|
||||
if err == nil {
|
||||
@@ -338,7 +301,6 @@ func AssertError(t *testing.T, err error, expectedMsg string) {
|
||||
}
|
||||
}
|
||||
|
||||
// AssertEqual checks if two values are equal
|
||||
func AssertEqual(t *testing.T, expected, actual interface{}) {
|
||||
t.Helper()
|
||||
if expected != actual {
|
||||
@@ -346,7 +308,6 @@ func AssertEqual(t *testing.T, expected, actual interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// AssertNotNil checks if a value is not nil
|
||||
func AssertNotNil(t *testing.T, value interface{}) {
|
||||
t.Helper()
|
||||
if value == nil {
|
||||
@@ -354,12 +315,9 @@ func AssertNotNil(t *testing.T, value interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// AssertNil checks if a value is nil
|
||||
func AssertNil(t *testing.T, value interface{}) {
|
||||
t.Helper()
|
||||
if value != nil {
|
||||
// Special handling for interface values that contain nil but aren't nil themselves
|
||||
// For example, (*jwt.Claims)(nil) is not equal to nil, but it contains nil
|
||||
switch v := value.(type) {
|
||||
case *interface{}:
|
||||
if v == nil || *v == nil {
|
||||
@@ -374,13 +332,11 @@ func AssertNil(t *testing.T, value interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeUTF16LEBOM encodes UTF-8 bytes to UTF-16 LE BOM format
|
||||
func EncodeUTF16LEBOM(input []byte) ([]byte, error) {
|
||||
encoder := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM)
|
||||
return transformBytes(encoder.NewEncoder(), input)
|
||||
}
|
||||
|
||||
// transformBytes applies a transform to input bytes
|
||||
func transformBytes(t transform.Transformer, input []byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
w := transform.NewWriter(&buf, t)
|
||||
@@ -396,7 +352,6 @@ func transformBytes(t transform.Transformer, input []byte) ([]byte, error) {
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// ErrorForTesting creates an error for testing purposes
|
||||
func ErrorForTesting(message string) error {
|
||||
return errors.New(message)
|
||||
}
|
||||
|
||||
9
tests/testdata/state_history_data.go
vendored
9
tests/testdata/state_history_data.go
vendored
@@ -7,13 +7,11 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// StateHistoryTestData provides simple test data generators
|
||||
type StateHistoryTestData struct {
|
||||
ServerID uuid.UUID
|
||||
BaseTime time.Time
|
||||
}
|
||||
|
||||
// NewStateHistoryTestData creates a new test data generator
|
||||
func NewStateHistoryTestData(serverID uuid.UUID) *StateHistoryTestData {
|
||||
return &StateHistoryTestData{
|
||||
ServerID: serverID,
|
||||
@@ -21,7 +19,6 @@ func NewStateHistoryTestData(serverID uuid.UUID) *StateHistoryTestData {
|
||||
}
|
||||
}
|
||||
|
||||
// CreateStateHistory creates a basic state history entry
|
||||
func (td *StateHistoryTestData) CreateStateHistory(session model.TrackSession, track string, playerCount int, sessionID uuid.UUID) model.StateHistory {
|
||||
return model.StateHistory{
|
||||
ID: uuid.New(),
|
||||
@@ -36,7 +33,6 @@ func (td *StateHistoryTestData) CreateStateHistory(session model.TrackSession, t
|
||||
}
|
||||
}
|
||||
|
||||
// CreateMultipleEntries creates multiple state history entries for the same session
|
||||
func (td *StateHistoryTestData) CreateMultipleEntries(session model.TrackSession, track string, playerCounts []int) []model.StateHistory {
|
||||
sessionID := uuid.New()
|
||||
var entries []model.StateHistory
|
||||
@@ -59,7 +55,6 @@ func (td *StateHistoryTestData) CreateMultipleEntries(session model.TrackSession
|
||||
return entries
|
||||
}
|
||||
|
||||
// CreateBasicFilter creates a basic filter for testing
|
||||
func CreateBasicFilter(serverID string) *model.StateHistoryFilter {
|
||||
return &model.StateHistoryFilter{
|
||||
ServerBasedFilter: model.ServerBasedFilter{
|
||||
@@ -68,7 +63,6 @@ func CreateBasicFilter(serverID string) *model.StateHistoryFilter {
|
||||
}
|
||||
}
|
||||
|
||||
// CreateFilterWithSession creates a filter with session type
|
||||
func CreateFilterWithSession(serverID string, session model.TrackSession) *model.StateHistoryFilter {
|
||||
return &model.StateHistoryFilter{
|
||||
ServerBasedFilter: model.ServerBasedFilter{
|
||||
@@ -78,7 +72,6 @@ func CreateFilterWithSession(serverID string, session model.TrackSession) *model
|
||||
}
|
||||
}
|
||||
|
||||
// LogLines contains sample ACC server log lines for testing
|
||||
var SampleLogLines = []string{
|
||||
"[2024-01-15 14:30:25.123] Session changed: NONE -> PRACTICE",
|
||||
"[2024-01-15 14:30:30.456] 1 client(s) online",
|
||||
@@ -95,7 +88,6 @@ var SampleLogLines = []string{
|
||||
"[2024-01-15 15:00:05.123] Session changed: RACE -> NONE",
|
||||
}
|
||||
|
||||
// ExpectedSessionChanges represents the expected session changes from parsing the sample log lines
|
||||
var ExpectedSessionChanges = []struct {
|
||||
From model.TrackSession
|
||||
To model.TrackSession
|
||||
@@ -106,5 +98,4 @@ var ExpectedSessionChanges = []struct {
|
||||
{model.SessionRace, model.SessionUnknown},
|
||||
}
|
||||
|
||||
// ExpectedPlayerCounts represents the expected player counts from parsing the sample log lines
|
||||
var ExpectedPlayerCounts = []int{1, 3, 5, 8, 12, 15, 14, 0}
|
||||
|
||||
@@ -14,12 +14,10 @@ import (
|
||||
)
|
||||
|
||||
func TestController_JSONParsing_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test basic JSON parsing functionality
|
||||
app := fiber.New()
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
@@ -30,7 +28,6 @@ func TestController_JSONParsing_Success(t *testing.T) {
|
||||
return c.JSON(data)
|
||||
})
|
||||
|
||||
// Prepare test data
|
||||
testData := map[string]interface{}{
|
||||
"name": "test",
|
||||
"value": 123,
|
||||
@@ -38,34 +35,28 @@ func TestController_JSONParsing_Success(t *testing.T) {
|
||||
bodyBytes, err := json.Marshal(testData)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
// Parse response
|
||||
var response map[string]interface{}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
err = json.Unmarshal(body, &response)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, "test", response["name"])
|
||||
tests.AssertEqual(t, float64(123), response["value"]) // JSON numbers are float64
|
||||
tests.AssertEqual(t, float64(123), response["value"])
|
||||
}
|
||||
|
||||
func TestController_JSONParsing_InvalidJSON(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test handling of invalid JSON
|
||||
app := fiber.New()
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
@@ -76,39 +67,31 @@ func TestController_JSONParsing_InvalidJSON(t *testing.T) {
|
||||
return c.JSON(data)
|
||||
})
|
||||
|
||||
// Create request with invalid JSON
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewReader([]byte("invalid json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
// Parse error response
|
||||
var response map[string]interface{}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
err = json.Unmarshal(body, &response)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify error response
|
||||
tests.AssertEqual(t, "Invalid JSON", response["error"])
|
||||
}
|
||||
|
||||
func TestController_UUIDValidation_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test UUID parameter validation
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/test/:id", func(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
|
||||
// Validate UUID
|
||||
if _, err := uuid.Parse(id); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid UUID"})
|
||||
}
|
||||
@@ -116,40 +99,33 @@ func TestController_UUIDValidation_Success(t *testing.T) {
|
||||
return c.JSON(fiber.Map{"id": id, "valid": true})
|
||||
})
|
||||
|
||||
// Create request with valid UUID
|
||||
validUUID := uuid.New().String()
|
||||
req := httptest.NewRequest("GET", "/test/"+validUUID, nil)
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
// Parse response
|
||||
var response map[string]interface{}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
err = json.Unmarshal(body, &response)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, validUUID, response["id"])
|
||||
tests.AssertEqual(t, true, response["valid"])
|
||||
}
|
||||
|
||||
func TestController_UUIDValidation_InvalidUUID(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test handling of invalid UUID
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/test/:id", func(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
|
||||
// Validate UUID
|
||||
if _, err := uuid.Parse(id); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid UUID"})
|
||||
}
|
||||
@@ -157,32 +133,26 @@ func TestController_UUIDValidation_InvalidUUID(t *testing.T) {
|
||||
return c.JSON(fiber.Map{"id": id, "valid": true})
|
||||
})
|
||||
|
||||
// Create request with invalid UUID
|
||||
req := httptest.NewRequest("GET", "/test/invalid-uuid", nil)
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
// Parse error response
|
||||
var response map[string]interface{}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
err = json.Unmarshal(body, &response)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify error response
|
||||
tests.AssertEqual(t, "Invalid UUID", response["error"])
|
||||
}
|
||||
|
||||
func TestController_QueryParameters_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test query parameter handling
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
@@ -197,34 +167,28 @@ func TestController_QueryParameters_Success(t *testing.T) {
|
||||
})
|
||||
})
|
||||
|
||||
// Create request with query parameters
|
||||
req := httptest.NewRequest("GET", "/test?restart=true&override=false&format=xml", nil)
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
// Parse response
|
||||
var response map[string]interface{}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
err = json.Unmarshal(body, &response)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, true, response["restart"])
|
||||
tests.AssertEqual(t, false, response["override"])
|
||||
tests.AssertEqual(t, "xml", response["format"])
|
||||
}
|
||||
|
||||
func TestController_HTTPMethods_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test different HTTP methods
|
||||
app := fiber.New()
|
||||
|
||||
var getCalled, postCalled, putCalled, deleteCalled bool
|
||||
@@ -249,28 +213,24 @@ func TestController_HTTPMethods_Success(t *testing.T) {
|
||||
return c.JSON(fiber.Map{"method": "DELETE"})
|
||||
})
|
||||
|
||||
// Test GET
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
tests.AssertEqual(t, true, getCalled)
|
||||
|
||||
// Test POST
|
||||
req = httptest.NewRequest("POST", "/test", nil)
|
||||
resp, err = app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
tests.AssertEqual(t, true, postCalled)
|
||||
|
||||
// Test PUT
|
||||
req = httptest.NewRequest("PUT", "/test", nil)
|
||||
resp, err = app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
tests.AssertEqual(t, true, putCalled)
|
||||
|
||||
// Test DELETE
|
||||
req = httptest.NewRequest("DELETE", "/test", nil)
|
||||
resp, err = app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
@@ -279,12 +239,10 @@ func TestController_HTTPMethods_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestController_ErrorHandling_StatusCodes(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test different error status codes
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/400", func(c *fiber.Ctx) error {
|
||||
@@ -307,7 +265,6 @@ func TestController_ErrorHandling_StatusCodes(t *testing.T) {
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Internal Server Error"})
|
||||
})
|
||||
|
||||
// Test different status codes
|
||||
testCases := []struct {
|
||||
path string
|
||||
code int
|
||||
@@ -328,12 +285,10 @@ func TestController_ErrorHandling_StatusCodes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestController_ConfigurationModel_JSONSerialization(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test Configuration model JSON serialization
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/config", func(c *fiber.Ctx) error {
|
||||
@@ -348,22 +303,18 @@ func TestController_ConfigurationModel_JSONSerialization(t *testing.T) {
|
||||
return c.JSON(config)
|
||||
})
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest("GET", "/config", nil)
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
// Parse response
|
||||
var response model.Configuration
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
err = json.Unmarshal(body, &response)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, model.IntString(9231), response.UdpPort)
|
||||
tests.AssertEqual(t, model.IntString(9232), response.TcpPort)
|
||||
tests.AssertEqual(t, model.IntString(30), response.MaxConnections)
|
||||
@@ -373,73 +324,61 @@ func TestController_ConfigurationModel_JSONSerialization(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestController_UserModel_JSONSerialization(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test User model JSON serialization (password should be hidden)
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/user", func(c *fiber.Ctx) error {
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "testuser",
|
||||
Password: "secret-password", // Should not appear in JSON
|
||||
Password: "secret-password",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
return c.JSON(user)
|
||||
})
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest("GET", "/user", nil)
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
// Parse response as raw JSON to check password is excluded
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify password field is not in JSON
|
||||
if bytes.Contains(body, []byte("password")) || bytes.Contains(body, []byte("secret-password")) {
|
||||
t.Fatal("Password should not be included in JSON response")
|
||||
}
|
||||
|
||||
// Verify other fields are present
|
||||
if !bytes.Contains(body, []byte("username")) || !bytes.Contains(body, []byte("testuser")) {
|
||||
t.Fatal("Username should be included in JSON response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestController_MiddlewareChaining_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test middleware chaining
|
||||
app := fiber.New()
|
||||
|
||||
var middleware1Called, middleware2Called, handlerCalled bool
|
||||
|
||||
// Middleware 1
|
||||
middleware1 := func(c *fiber.Ctx) error {
|
||||
middleware1Called = true
|
||||
c.Locals("middleware1", "executed")
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Middleware 2
|
||||
middleware2 := func(c *fiber.Ctx) error {
|
||||
middleware2Called = true
|
||||
c.Locals("middleware2", "executed")
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Handler
|
||||
handler := func(c *fiber.Ctx) error {
|
||||
handlerCalled = true
|
||||
return c.JSON(fiber.Map{
|
||||
@@ -451,27 +390,22 @@ func TestController_MiddlewareChaining_Success(t *testing.T) {
|
||||
|
||||
app.Get("/test", middleware1, middleware2, handler)
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
// Verify all were called
|
||||
tests.AssertEqual(t, true, middleware1Called)
|
||||
tests.AssertEqual(t, true, middleware2Called)
|
||||
tests.AssertEqual(t, true, handlerCalled)
|
||||
|
||||
// Parse response
|
||||
var response map[string]interface{}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
err = json.Unmarshal(body, &response)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify middleware values were passed
|
||||
tests.AssertEqual(t, "executed", response["middleware1"])
|
||||
tests.AssertEqual(t, "executed", response["middleware2"])
|
||||
tests.AssertEqual(t, "executed", response["handler"])
|
||||
|
||||
@@ -11,27 +11,20 @@ import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// MockMiddleware simulates authentication for testing purposes
|
||||
type MockMiddleware struct{}
|
||||
|
||||
// GetTestAuthMiddleware returns a mock auth middleware that can be used in place of the real one
|
||||
// This works because we're adding real authentication tokens to requests
|
||||
func GetTestAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache) *middleware.AuthMiddleware {
|
||||
// Use environment JWT secrets for consistency with token generation
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
|
||||
}
|
||||
|
||||
|
||||
jwtHandler := jwt.NewJWTHandler(jwtSecret)
|
||||
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret) // Use same secret for test consistency
|
||||
|
||||
// Cast our mock to the real type for testing
|
||||
// This is a type-unsafe cast but works for testing because we're using real JWT tokens
|
||||
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
|
||||
|
||||
return middleware.NewAuthMiddleware(ms, cache, jwtHandler, openJWTHandler)
|
||||
}
|
||||
|
||||
// AddAuthToRequest adds a valid authentication token to a test request
|
||||
func AddAuthToRequest(req *fiber.Ctx) {
|
||||
token := tests.MustGenerateTestToken()
|
||||
req.Request().Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
@@ -23,13 +23,11 @@ import (
|
||||
)
|
||||
|
||||
func TestStateHistoryController_GetAll_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// No need for DisableAuthentication, we'll use real auth tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
@@ -44,33 +42,26 @@ func TestStateHistoryController_GetAll_Success(t *testing.T) {
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Insert test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
err := repo.Insert(helper.CreateContext(), &history)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Create request with authentication
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Parse response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
@@ -83,13 +74,11 @@ func TestStateHistoryController_GetAll_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
@@ -104,7 +93,6 @@ func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) {
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Insert test data with different sessions
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
|
||||
practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
@@ -115,27 +103,21 @@ func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) {
|
||||
err = repo.Insert(helper.CreateContext(), &raceHistory)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Create request with session filter and authentication
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s&session=R", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Parse response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
@@ -148,13 +130,11 @@ func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryController_GetAll_EmptyResult(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
@@ -169,38 +149,30 @@ func TestStateHistoryController_GetAll_EmptyResult(t *testing.T) {
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Create request with no data and authentication
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify empty response
|
||||
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
|
||||
// Skip this test as it requires more complex setup
|
||||
t.Skip("Skipping test due to UUID validation issues")
|
||||
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
@@ -215,10 +187,8 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Insert test data with multiple entries for statistics
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
|
||||
// Create entries with varying player counts
|
||||
playerCounts := []int{5, 10, 15, 20, 25}
|
||||
entries := testData.CreateMultipleEntries(model.SessionRace, "spa", playerCounts)
|
||||
|
||||
@@ -227,33 +197,26 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
|
||||
tests.AssertNoError(t, err)
|
||||
}
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Create request with valid serverID UUID
|
||||
validServerID := helper.TestData.ServerID.String()
|
||||
if validServerID == "" {
|
||||
validServerID = uuid.New().String() // Generate a new valid UUID if needed
|
||||
validServerID = uuid.New().String()
|
||||
}
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Add Authorization header for testing
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Parse response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
@@ -261,7 +224,6 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
|
||||
err = json.Unmarshal(body, &stats)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify statistics structure exists (actual calculation is tested in service layer)
|
||||
if stats.PeakPlayers < 0 {
|
||||
t.Error("Expected non-negative peak players")
|
||||
}
|
||||
@@ -274,16 +236,13 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryController_GetStatistics_NoData(t *testing.T) {
|
||||
// Skip this test as it requires more complex setup
|
||||
t.Skip("Skipping test due to UUID validation issues")
|
||||
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
@@ -298,33 +257,26 @@ func TestStateHistoryController_GetStatistics_NoData(t *testing.T) {
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Create request with valid serverID UUID
|
||||
validServerID := helper.TestData.ServerID.String()
|
||||
if validServerID == "" {
|
||||
validServerID = uuid.New().String() // Generate a new valid UUID if needed
|
||||
validServerID = uuid.New().String()
|
||||
}
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Add Authorization header for testing
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Parse response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
@@ -332,23 +284,19 @@ func TestStateHistoryController_GetStatistics_NoData(t *testing.T) {
|
||||
err = json.Unmarshal(body, &stats)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify empty statistics
|
||||
tests.AssertEqual(t, 0, stats.PeakPlayers)
|
||||
tests.AssertEqual(t, 0.0, stats.AveragePlayers)
|
||||
tests.AssertEqual(t, 0, stats.TotalSessions)
|
||||
}
|
||||
|
||||
func TestStateHistoryController_GetStatistics_InvalidQueryParams(t *testing.T) {
|
||||
// Skip this test as it requires more complex setup
|
||||
t.Skip("Skipping test due to UUID validation issues")
|
||||
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
@@ -363,42 +311,34 @@ func TestStateHistoryController_GetStatistics_InvalidQueryParams(t *testing.T) {
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Create request with invalid query parameters but with valid UUID
|
||||
validServerID := helper.TestData.ServerID.String()
|
||||
if validServerID == "" {
|
||||
validServerID = uuid.New().String() // Generate a new valid UUID if needed
|
||||
validServerID = uuid.New().String()
|
||||
}
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s&min_players=invalid", validServerID), nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Add Authorization header for testing
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify error response
|
||||
tests.AssertEqual(t, http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestStateHistoryController_HTTPMethods(t *testing.T) {
|
||||
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
@@ -413,36 +353,30 @@ func TestStateHistoryController_HTTPMethods(t *testing.T) {
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Test that only GET method is allowed for GetAll
|
||||
req := httptest.NewRequest("POST", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode)
|
||||
|
||||
// Test that only GET method is allowed for GetStatistics
|
||||
req = httptest.NewRequest("POST", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
resp, err = app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode)
|
||||
|
||||
// Test that PUT method is not allowed
|
||||
req = httptest.NewRequest("PUT", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
resp, err = app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode)
|
||||
|
||||
// Test that DELETE method is not allowed
|
||||
req = httptest.NewRequest("DELETE", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
resp, err = app.Test(req)
|
||||
@@ -452,13 +386,11 @@ func TestStateHistoryController_HTTPMethods(t *testing.T) {
|
||||
|
||||
func TestStateHistoryController_ContentType(t *testing.T) {
|
||||
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
@@ -473,43 +405,36 @@ func TestStateHistoryController_ContentType(t *testing.T) {
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Insert test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
err := repo.Insert(helper.CreateContext(), &history)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Test GetAll endpoint with authentication
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify content type is JSON
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Expected Content-Type: application/json, got %s", contentType)
|
||||
}
|
||||
|
||||
// Test GetStatistics endpoint with authentication
|
||||
validServerID := helper.TestData.ServerID.String()
|
||||
if validServerID == "" {
|
||||
validServerID = uuid.New().String() // Generate a new valid UUID if needed
|
||||
validServerID = uuid.New().String()
|
||||
}
|
||||
req = httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
resp, err = app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify content type is JSON
|
||||
contentType = resp.Header.Get("Content-Type")
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Expected Content-Type: application/json, got %s", contentType)
|
||||
@@ -517,16 +442,13 @@ func TestStateHistoryController_ContentType(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryController_ResponseStructure(t *testing.T) {
|
||||
// Skip this test as it's problematic and would require deeper investigation
|
||||
t.Skip("Skipping test due to response structure issues that need further investigation")
|
||||
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
@@ -541,7 +463,6 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) {
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -549,21 +470,17 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
err := repo.Insert(helper.CreateContext(), &history)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Test GetAll response structure with authentication
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
resp, err := app.Test(req)
|
||||
@@ -572,24 +489,19 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Log the actual response for debugging
|
||||
t.Logf("Response body: %s", string(body))
|
||||
|
||||
// Try parsing as array first
|
||||
var resultArray []model.StateHistory
|
||||
err = json.Unmarshal(body, &resultArray)
|
||||
if err != nil {
|
||||
// If array parsing fails, try parsing as a single object
|
||||
var singleResult model.StateHistory
|
||||
err = json.Unmarshal(body, &singleResult)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response as either array or object: %v", err)
|
||||
}
|
||||
// Convert single result to array
|
||||
resultArray = []model.StateHistory{singleResult}
|
||||
}
|
||||
|
||||
// Verify StateHistory structure
|
||||
if len(resultArray) > 0 {
|
||||
history := resultArray[0]
|
||||
if history.ID == uuid.Nil {
|
||||
|
||||
@@ -12,12 +12,10 @@ import (
|
||||
)
|
||||
|
||||
func TestStateHistoryRepository_Insert_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -28,15 +26,12 @@ func TestStateHistoryRepository_Insert_Success(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Create test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
|
||||
// Test Insert
|
||||
err := repo.Insert(ctx, &history)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify ID was generated
|
||||
tests.AssertNotNil(t, history.ID)
|
||||
if history.ID == uuid.Nil {
|
||||
t.Error("Expected non-nil ID after insert")
|
||||
@@ -44,12 +39,10 @@ func TestStateHistoryRepository_Insert_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_GetAll_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -60,10 +53,8 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Create test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
|
||||
// Insert multiple entries
|
||||
playerCounts := []int{0, 5, 10, 15, 10, 5, 0}
|
||||
entries := testData.CreateMultipleEntries(model.SessionPractice, "spa", playerCounts)
|
||||
|
||||
@@ -72,7 +63,6 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) {
|
||||
tests.AssertNoError(t, err)
|
||||
}
|
||||
|
||||
// Test GetAll
|
||||
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
|
||||
result, err := repo.GetAll(ctx, filter)
|
||||
|
||||
@@ -82,12 +72,10 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -98,19 +86,16 @@ func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Create test data with different sessions
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
|
||||
practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
raceHistory := testData.CreateStateHistory(model.SessionRace, "spa", 15, uuid.New())
|
||||
|
||||
// Insert both
|
||||
err := repo.Insert(ctx, &practiceHistory)
|
||||
tests.AssertNoError(t, err)
|
||||
err = repo.Insert(ctx, &raceHistory)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Test GetAll with session filter
|
||||
filter := testdata.CreateFilterWithSession(helper.TestData.ServerID.String(), model.SessionRace)
|
||||
result, err := repo.GetAll(ctx, filter)
|
||||
|
||||
@@ -122,12 +107,10 @@ func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_GetLastSessionID_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -138,44 +121,34 @@ func TestStateHistoryRepository_GetLastSessionID_Success(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Create test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
|
||||
// Insert multiple entries with different session IDs
|
||||
sessionID1 := uuid.New()
|
||||
sessionID2 := uuid.New()
|
||||
|
||||
history1 := testData.CreateStateHistory(model.SessionPractice, "spa", 5, sessionID1)
|
||||
history2 := testData.CreateStateHistory(model.SessionRace, "spa", 10, sessionID2)
|
||||
|
||||
// Insert with a small delay to ensure ordering
|
||||
err := repo.Insert(ctx, &history1)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
time.Sleep(1 * time.Millisecond) // Ensure different timestamps
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
|
||||
err = repo.Insert(ctx, &history2)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Test GetLastSessionID - should return the most recent session ID
|
||||
lastSessionID, err := repo.GetLastSessionID(ctx, helper.TestData.ServerID)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Should be sessionID2 since it was inserted last
|
||||
// We should get the most recently inserted session ID, but the exact value doesn't matter
|
||||
// Just check that it's not nil and that it's a valid UUID
|
||||
if lastSessionID == uuid.Nil {
|
||||
t.Fatal("Expected non-nil UUID for last session ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_GetLastSessionID_NoData(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -186,19 +159,16 @@ func TestStateHistoryRepository_GetLastSessionID_NoData(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Test GetLastSessionID with no data
|
||||
lastSessionID, err := repo.GetLastSessionID(ctx, helper.TestData.ServerID)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, uuid.Nil, lastSessionID)
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -209,14 +179,11 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Create test data with varying player counts
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
|
||||
// Create entries with different sessions and player counts
|
||||
sessionID1 := uuid.New()
|
||||
sessionID2 := uuid.New()
|
||||
|
||||
// Practice session: 5, 10, 15 players
|
||||
practiceEntries := testData.CreateMultipleEntries(model.SessionPractice, "spa", []int{5, 10, 15})
|
||||
for i := range practiceEntries {
|
||||
practiceEntries[i].SessionID = sessionID1
|
||||
@@ -224,7 +191,6 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
|
||||
tests.AssertNoError(t, err)
|
||||
}
|
||||
|
||||
// Race session: 20, 25, 30 players
|
||||
raceEntries := testData.CreateMultipleEntries(model.SessionRace, "spa", []int{20, 25, 30})
|
||||
for i := range raceEntries {
|
||||
raceEntries[i].SessionID = sessionID2
|
||||
@@ -232,17 +198,14 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
|
||||
tests.AssertNoError(t, err)
|
||||
}
|
||||
|
||||
// Test GetSummaryStats
|
||||
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
|
||||
stats, err := repo.GetSummaryStats(ctx, filter)
|
||||
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify stats are calculated correctly
|
||||
tests.AssertEqual(t, 30, stats.PeakPlayers) // Maximum player count
|
||||
tests.AssertEqual(t, 2, stats.TotalSessions) // Two unique sessions
|
||||
tests.AssertEqual(t, 30, stats.PeakPlayers)
|
||||
tests.AssertEqual(t, 2, stats.TotalSessions)
|
||||
|
||||
// Average should be (5+10+15+20+25+30)/6 = 17.5
|
||||
expectedAverage := float64(5+10+15+20+25+30) / 6.0
|
||||
if stats.AveragePlayers != expectedAverage {
|
||||
t.Errorf("Expected average players %.1f, got %.1f", expectedAverage, stats.AveragePlayers)
|
||||
@@ -250,12 +213,10 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_GetSummaryStats_NoData(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -266,25 +227,21 @@ func TestStateHistoryRepository_GetSummaryStats_NoData(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Test GetSummaryStats with no data
|
||||
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
|
||||
stats, err := repo.GetSummaryStats(ctx, filter)
|
||||
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify stats are zero for empty dataset
|
||||
tests.AssertEqual(t, 0, stats.PeakPlayers)
|
||||
tests.AssertEqual(t, 0.0, stats.AveragePlayers)
|
||||
tests.AssertEqual(t, 0, stats.TotalSessions)
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -294,13 +251,10 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
|
||||
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Create test data spanning a time range
|
||||
sessionID := uuid.New()
|
||||
|
||||
baseTime := time.Now().UTC()
|
||||
|
||||
// Create entries spanning 1 hour with players > 0
|
||||
entries := []model.StateHistory{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
@@ -342,7 +296,6 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
|
||||
tests.AssertNoError(t, err)
|
||||
}
|
||||
|
||||
// Test GetTotalPlaytime
|
||||
filter := &model.StateHistoryFilter{
|
||||
ServerBasedFilter: model.ServerBasedFilter{
|
||||
ServerID: helper.TestData.ServerID.String(),
|
||||
@@ -355,20 +308,16 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
|
||||
|
||||
playtime, err := repo.GetTotalPlaytime(ctx, filter)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Should calculate playtime based on session duration
|
||||
if playtime <= 0 {
|
||||
t.Error("Expected positive playtime for session with multiple entries")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
|
||||
// Test concurrent database operations
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -379,10 +328,8 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Create test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -390,7 +337,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Create and insert initial entry to ensure table exists and is properly set up
|
||||
initialHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
err := repo.Insert(ctx, &initialHistory)
|
||||
if err != nil {
|
||||
@@ -399,7 +345,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
|
||||
|
||||
done := make(chan bool, 3)
|
||||
|
||||
// Concurrent inserts
|
||||
go func() {
|
||||
defer func() {
|
||||
done <- true
|
||||
@@ -412,7 +357,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Concurrent reads
|
||||
go func() {
|
||||
defer func() {
|
||||
done <- true
|
||||
@@ -425,7 +369,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Concurrent GetLastSessionID
|
||||
go func() {
|
||||
defer func() {
|
||||
done <- true
|
||||
@@ -437,19 +380,16 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for all operations to complete
|
||||
for i := 0; i < 3; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) {
|
||||
// Test edge cases with filters
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -460,15 +400,11 @@ func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Insert a test record to ensure the table is properly set up
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
err := repo.Insert(ctx, &history)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Skip nil filter test as it might not be supported by the repository implementation
|
||||
|
||||
// Test with server ID filter - this should work
|
||||
serverFilter := &model.StateHistoryFilter{
|
||||
ServerBasedFilter: model.ServerBasedFilter{
|
||||
ServerID: helper.TestData.ServerID.String(),
|
||||
@@ -478,7 +414,6 @@ func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) {
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, result)
|
||||
|
||||
// Test with invalid server ID in summary stats
|
||||
invalidFilter := &model.StateHistoryFilter{
|
||||
ServerBasedFilter: model.ServerBasedFilter{
|
||||
ServerID: "invalid-uuid",
|
||||
|
||||
@@ -12,30 +12,24 @@ import (
|
||||
)
|
||||
|
||||
func TestJWT_GenerateAndValidateToken(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET"))
|
||||
|
||||
// Create test user
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "testuser",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Test JWT generation
|
||||
token, err := jwtHandler.GenerateToken(user.ID.String())
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, token)
|
||||
|
||||
// Verify token is not empty
|
||||
if token == "" {
|
||||
t.Fatal("Expected non-empty token, got empty string")
|
||||
}
|
||||
|
||||
// Test JWT validation
|
||||
claims, err := jwtHandler.ValidateToken(token)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, claims)
|
||||
@@ -43,80 +37,66 @@ func TestJWT_GenerateAndValidateToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestJWT_ValidateToken_InvalidToken(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET"))
|
||||
|
||||
// Test with invalid token
|
||||
claims, err := jwtHandler.ValidateToken("invalid-token")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid token, got nil")
|
||||
}
|
||||
// Direct nil check to avoid the interface wrapping issue
|
||||
if claims != nil {
|
||||
t.Fatalf("Expected nil claims, got %v", claims)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWT_ValidateToken_EmptyToken(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET"))
|
||||
|
||||
// Test with empty token
|
||||
claims, err := jwtHandler.ValidateToken("")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty token, got nil")
|
||||
}
|
||||
// Direct nil check to avoid the interface wrapping issue
|
||||
if claims != nil {
|
||||
t.Fatalf("Expected nil claims, got %v", claims)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUser_VerifyPassword_Success(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create test user
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "testuser",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Hash password manually (simulating what BeforeCreate would do)
|
||||
plainPassword := "password123"
|
||||
hashedPassword, err := password.HashPassword(plainPassword)
|
||||
tests.AssertNoError(t, err)
|
||||
user.Password = hashedPassword
|
||||
|
||||
// Test password verification - should succeed
|
||||
err = user.VerifyPassword(plainPassword)
|
||||
tests.AssertNoError(t, err)
|
||||
}
|
||||
|
||||
func TestUser_VerifyPassword_Failure(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create test user
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "testuser",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Hash password manually
|
||||
hashedPassword, err := password.HashPassword("correct_password")
|
||||
tests.AssertNoError(t, err)
|
||||
user.Password = hashedPassword
|
||||
|
||||
// Test password verification with wrong password - should fail
|
||||
err = user.VerifyPassword("wrong_password")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for wrong password, got nil")
|
||||
@@ -124,11 +104,9 @@ func TestUser_VerifyPassword_Failure(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_Validate_Success(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create valid user
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "testuser",
|
||||
@@ -136,25 +114,21 @@ func TestUser_Validate_Success(t *testing.T) {
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Test validation - should succeed
|
||||
err := user.Validate()
|
||||
tests.AssertNoError(t, err)
|
||||
}
|
||||
|
||||
func TestUser_Validate_MissingUsername(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create user without username
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "", // Missing username
|
||||
Username: "",
|
||||
Password: "password123",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Test validation - should fail
|
||||
err := user.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing username, got nil")
|
||||
@@ -162,19 +136,16 @@ func TestUser_Validate_MissingUsername(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_Validate_MissingPassword(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create user without password
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "testuser",
|
||||
Password: "", // Missing password
|
||||
Password: "",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Test validation - should fail
|
||||
err := user.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing password, got nil")
|
||||
@@ -182,24 +153,19 @@ func TestUser_Validate_MissingPassword(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPassword_HashAndVerify(t *testing.T) {
|
||||
// Test password hashing and verification directly
|
||||
plainPassword := "test_password_123"
|
||||
|
||||
// Hash password
|
||||
hashedPassword, err := password.HashPassword(plainPassword)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, hashedPassword)
|
||||
|
||||
// Verify hashed password is not the same as plain password
|
||||
if hashedPassword == plainPassword {
|
||||
t.Fatal("Hashed password should not equal plain password")
|
||||
}
|
||||
|
||||
// Verify correct password
|
||||
err = password.VerifyPassword(hashedPassword, plainPassword)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify wrong password fails
|
||||
err = password.VerifyPassword(hashedPassword, "wrong_password")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for wrong password, got nil")
|
||||
@@ -215,7 +181,7 @@ func TestPassword_ValidatePasswordStrength(t *testing.T) {
|
||||
{"Valid password", "StrongPassword123!", false},
|
||||
{"Too short", "123", true},
|
||||
{"Empty password", "", true},
|
||||
{"Medium password", "password123", false}, // Depends on validation rules
|
||||
{"Medium password", "password123", false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -235,7 +201,6 @@ func TestPassword_ValidatePasswordStrength(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRole_Model(t *testing.T) {
|
||||
// Test Role model structure
|
||||
permissions := []model.Permission{
|
||||
{ID: uuid.New(), Name: "read"},
|
||||
{ID: uuid.New(), Name: "write"},
|
||||
@@ -248,7 +213,6 @@ func TestRole_Model(t *testing.T) {
|
||||
Permissions: permissions,
|
||||
}
|
||||
|
||||
// Verify role structure
|
||||
tests.AssertEqual(t, "Test Role", role.Name)
|
||||
tests.AssertEqual(t, 3, len(role.Permissions))
|
||||
tests.AssertEqual(t, "read", role.Permissions[0].Name)
|
||||
@@ -257,19 +221,16 @@ func TestRole_Model(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPermission_Model(t *testing.T) {
|
||||
// Test Permission model structure
|
||||
permission := &model.Permission{
|
||||
ID: uuid.New(),
|
||||
Name: "test_permission",
|
||||
}
|
||||
|
||||
// Verify permission structure
|
||||
tests.AssertEqual(t, "test_permission", permission.Name)
|
||||
tests.AssertNotNil(t, permission.ID)
|
||||
}
|
||||
|
||||
func TestUser_WithRole_Model(t *testing.T) {
|
||||
// Test User model with Role relationship
|
||||
permissions := []model.Permission{
|
||||
{ID: uuid.New(), Name: "read"},
|
||||
{ID: uuid.New(), Name: "write"},
|
||||
@@ -289,7 +250,6 @@ func TestUser_WithRole_Model(t *testing.T) {
|
||||
Role: role,
|
||||
}
|
||||
|
||||
// Verify user-role relationship
|
||||
tests.AssertEqual(t, "testuser", user.Username)
|
||||
tests.AssertEqual(t, role.ID, user.RoleID)
|
||||
tests.AssertEqual(t, "User", user.Role.Name)
|
||||
|
||||
@@ -11,28 +11,22 @@ import (
|
||||
)
|
||||
|
||||
func TestInMemoryCache_Set_Get_Success(t *testing.T) {
|
||||
// Setup
|
||||
c := cache.NewInMemoryCache()
|
||||
|
||||
// Test data
|
||||
key := "test-key"
|
||||
value := "test-value"
|
||||
duration := 5 * time.Minute
|
||||
|
||||
// Set value in cache
|
||||
c.Set(key, value, duration)
|
||||
|
||||
// Get value from cache
|
||||
result, found := c.Get(key)
|
||||
tests.AssertEqual(t, true, found)
|
||||
tests.AssertEqual(t, value, result)
|
||||
}
|
||||
|
||||
func TestInMemoryCache_Get_NotFound(t *testing.T) {
|
||||
// Setup
|
||||
c := cache.NewInMemoryCache()
|
||||
|
||||
// Try to get non-existent key
|
||||
result, found := c.Get("non-existent-key")
|
||||
tests.AssertEqual(t, false, found)
|
||||
if result != nil {
|
||||
@@ -41,43 +35,33 @@ func TestInMemoryCache_Get_NotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInMemoryCache_Set_Get_NoExpiration(t *testing.T) {
|
||||
// Setup
|
||||
c := cache.NewInMemoryCache()
|
||||
|
||||
// Test data
|
||||
key := "test-key"
|
||||
value := "test-value"
|
||||
|
||||
// Set value without expiration (duration = 0)
|
||||
c.Set(key, value, 0)
|
||||
|
||||
// Get value from cache
|
||||
result, found := c.Get(key)
|
||||
tests.AssertEqual(t, true, found)
|
||||
tests.AssertEqual(t, value, result)
|
||||
}
|
||||
|
||||
func TestInMemoryCache_Expiration(t *testing.T) {
|
||||
// Setup
|
||||
c := cache.NewInMemoryCache()
|
||||
|
||||
// Test data
|
||||
key := "test-key"
|
||||
value := "test-value"
|
||||
duration := 1 * time.Millisecond // Very short duration
|
||||
duration := 1 * time.Millisecond
|
||||
|
||||
// Set value in cache
|
||||
c.Set(key, value, duration)
|
||||
|
||||
// Verify it's initially there
|
||||
result, found := c.Get(key)
|
||||
tests.AssertEqual(t, true, found)
|
||||
tests.AssertEqual(t, value, result)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
|
||||
// Try to get expired value
|
||||
result, found = c.Get(key)
|
||||
tests.AssertEqual(t, false, found)
|
||||
if result != nil {
|
||||
@@ -86,26 +70,20 @@ func TestInMemoryCache_Expiration(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInMemoryCache_Delete(t *testing.T) {
|
||||
// Setup
|
||||
c := cache.NewInMemoryCache()
|
||||
|
||||
// Test data
|
||||
key := "test-key"
|
||||
value := "test-value"
|
||||
duration := 5 * time.Minute
|
||||
|
||||
// Set value in cache
|
||||
c.Set(key, value, duration)
|
||||
|
||||
// Verify it's there
|
||||
result, found := c.Get(key)
|
||||
tests.AssertEqual(t, true, found)
|
||||
tests.AssertEqual(t, value, result)
|
||||
|
||||
// Delete the key
|
||||
c.Delete(key)
|
||||
|
||||
// Verify it's gone
|
||||
result, found = c.Get(key)
|
||||
tests.AssertEqual(t, false, found)
|
||||
if result != nil {
|
||||
@@ -114,37 +92,29 @@ func TestInMemoryCache_Delete(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInMemoryCache_Overwrite(t *testing.T) {
|
||||
// Setup
|
||||
c := cache.NewInMemoryCache()
|
||||
|
||||
// Test data
|
||||
key := "test-key"
|
||||
value1 := "test-value-1"
|
||||
value2 := "test-value-2"
|
||||
duration := 5 * time.Minute
|
||||
|
||||
// Set first value
|
||||
c.Set(key, value1, duration)
|
||||
|
||||
// Verify first value
|
||||
result, found := c.Get(key)
|
||||
tests.AssertEqual(t, true, found)
|
||||
tests.AssertEqual(t, value1, result)
|
||||
|
||||
// Overwrite with second value
|
||||
c.Set(key, value2, duration)
|
||||
|
||||
// Verify second value
|
||||
result, found = c.Get(key)
|
||||
tests.AssertEqual(t, true, found)
|
||||
tests.AssertEqual(t, value2, result)
|
||||
}
|
||||
|
||||
func TestInMemoryCache_Multiple_Keys(t *testing.T) {
|
||||
// Setup
|
||||
c := cache.NewInMemoryCache()
|
||||
|
||||
// Test data
|
||||
testData := map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
@@ -152,22 +122,18 @@ func TestInMemoryCache_Multiple_Keys(t *testing.T) {
|
||||
}
|
||||
duration := 5 * time.Minute
|
||||
|
||||
// Set multiple values
|
||||
for key, value := range testData {
|
||||
c.Set(key, value, duration)
|
||||
}
|
||||
|
||||
// Verify all values
|
||||
for key, expectedValue := range testData {
|
||||
result, found := c.Get(key)
|
||||
tests.AssertEqual(t, true, found)
|
||||
tests.AssertEqual(t, expectedValue, result)
|
||||
}
|
||||
|
||||
// Delete one key
|
||||
c.Delete("key2")
|
||||
|
||||
// Verify key2 is gone but others remain
|
||||
result, found := c.Get("key2")
|
||||
tests.AssertEqual(t, false, found)
|
||||
if result != nil {
|
||||
@@ -184,10 +150,8 @@ func TestInMemoryCache_Multiple_Keys(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInMemoryCache_Complex_Objects(t *testing.T) {
|
||||
// Setup
|
||||
c := cache.NewInMemoryCache()
|
||||
|
||||
// Test with complex object (User struct)
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "testuser",
|
||||
@@ -196,15 +160,12 @@ func TestInMemoryCache_Complex_Objects(t *testing.T) {
|
||||
key := "user:" + user.ID.String()
|
||||
duration := 5 * time.Minute
|
||||
|
||||
// Set user in cache
|
||||
c.Set(key, user, duration)
|
||||
|
||||
// Get user from cache
|
||||
result, found := c.Get(key)
|
||||
tests.AssertEqual(t, true, found)
|
||||
tests.AssertNotNil(t, result)
|
||||
|
||||
// Verify it's the same user
|
||||
cachedUser, ok := result.(*model.User)
|
||||
tests.AssertEqual(t, true, ok)
|
||||
tests.AssertEqual(t, user.ID, cachedUser.ID)
|
||||
@@ -212,33 +173,27 @@ func TestInMemoryCache_Complex_Objects(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInMemoryCache_GetOrSet_CacheHit(t *testing.T) {
|
||||
// Setup
|
||||
c := cache.NewInMemoryCache()
|
||||
|
||||
// Pre-populate cache
|
||||
key := "test-key"
|
||||
expectedValue := "cached-value"
|
||||
c.Set(key, expectedValue, 5*time.Minute)
|
||||
|
||||
// Track if fetcher is called
|
||||
fetcherCalled := false
|
||||
fetcher := func() (string, error) {
|
||||
fetcherCalled = true
|
||||
return "fetcher-value", nil
|
||||
}
|
||||
|
||||
// Use GetOrSet - should return cached value
|
||||
result, err := cache.GetOrSet(c, key, 5*time.Minute, fetcher)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, expectedValue, result)
|
||||
tests.AssertEqual(t, false, fetcherCalled) // Fetcher should not be called
|
||||
tests.AssertEqual(t, false, fetcherCalled)
|
||||
}
|
||||
|
||||
func TestInMemoryCache_GetOrSet_CacheMiss(t *testing.T) {
|
||||
// Setup
|
||||
c := cache.NewInMemoryCache()
|
||||
|
||||
// Track if fetcher is called
|
||||
fetcherCalled := false
|
||||
expectedValue := "fetcher-value"
|
||||
fetcher := func() (string, error) {
|
||||
@@ -248,35 +203,29 @@ func TestInMemoryCache_GetOrSet_CacheMiss(t *testing.T) {
|
||||
|
||||
key := "test-key"
|
||||
|
||||
// Use GetOrSet - should call fetcher and cache result
|
||||
result, err := cache.GetOrSet(c, key, 5*time.Minute, fetcher)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, expectedValue, result)
|
||||
tests.AssertEqual(t, true, fetcherCalled) // Fetcher should be called
|
||||
tests.AssertEqual(t, true, fetcherCalled)
|
||||
|
||||
// Verify value is now cached
|
||||
cachedResult, found := c.Get(key)
|
||||
tests.AssertEqual(t, true, found)
|
||||
tests.AssertEqual(t, expectedValue, cachedResult)
|
||||
}
|
||||
|
||||
func TestInMemoryCache_GetOrSet_FetcherError(t *testing.T) {
|
||||
// Setup
|
||||
c := cache.NewInMemoryCache()
|
||||
|
||||
// Fetcher that returns error
|
||||
fetcher := func() (string, error) {
|
||||
return "", tests.ErrorForTesting("fetcher error")
|
||||
}
|
||||
|
||||
key := "test-key"
|
||||
|
||||
// Use GetOrSet - should return error
|
||||
result, err := cache.GetOrSet(c, key, 5*time.Minute, fetcher)
|
||||
tests.AssertError(t, err, "")
|
||||
tests.AssertEqual(t, "", result)
|
||||
|
||||
// Verify nothing is cached
|
||||
cachedResult, found := c.Get(key)
|
||||
tests.AssertEqual(t, false, found)
|
||||
if cachedResult != nil {
|
||||
@@ -285,10 +234,8 @@ func TestInMemoryCache_GetOrSet_FetcherError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInMemoryCache_TypeSafety(t *testing.T) {
|
||||
// Setup
|
||||
c := cache.NewInMemoryCache()
|
||||
|
||||
// Test type safety with GetOrSet
|
||||
userFetcher := func() (*model.User, error) {
|
||||
return &model.User{
|
||||
ID: uuid.New(),
|
||||
@@ -298,13 +245,11 @@ func TestInMemoryCache_TypeSafety(t *testing.T) {
|
||||
|
||||
key := "user-key"
|
||||
|
||||
// Use GetOrSet with User type
|
||||
user, err := cache.GetOrSet(c, key, 5*time.Minute, userFetcher)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, user)
|
||||
tests.AssertEqual(t, "testuser", user.Username)
|
||||
|
||||
// Verify correct type is cached
|
||||
cachedResult, found := c.Get(key)
|
||||
tests.AssertEqual(t, true, found)
|
||||
cachedUser, ok := cachedResult.(*model.User)
|
||||
@@ -313,26 +258,21 @@ func TestInMemoryCache_TypeSafety(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInMemoryCache_Concurrent_Access(t *testing.T) {
|
||||
// Setup
|
||||
c := cache.NewInMemoryCache()
|
||||
|
||||
// Test concurrent access
|
||||
key := "concurrent-key"
|
||||
value := "concurrent-value"
|
||||
duration := 5 * time.Minute
|
||||
|
||||
// Run concurrent operations
|
||||
done := make(chan bool, 3)
|
||||
|
||||
// Goroutine 1: Set value
|
||||
go func() {
|
||||
c.Set(key, value, duration)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Goroutine 2: Get value
|
||||
go func() {
|
||||
time.Sleep(1 * time.Millisecond) // Small delay to ensure Set happens first
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
result, found := c.Get(key)
|
||||
if found {
|
||||
tests.AssertEqual(t, value, result)
|
||||
@@ -340,19 +280,16 @@ func TestInMemoryCache_Concurrent_Access(t *testing.T) {
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Goroutine 3: Delete value
|
||||
go func() {
|
||||
time.Sleep(2 * time.Millisecond) // Delay to ensure Set and Get happen first
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
c.Delete(key)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < 3; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify value is deleted
|
||||
result, found := c.Get(key)
|
||||
tests.AssertEqual(t, false, found)
|
||||
if result != nil {
|
||||
@@ -361,7 +298,6 @@ func TestInMemoryCache_Concurrent_Access(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServerStatusCache_GetStatus_NeedsRefresh(t *testing.T) {
|
||||
// Setup
|
||||
config := model.CacheConfig{
|
||||
ExpirationTime: 5 * time.Minute,
|
||||
ThrottleTime: 1 * time.Second,
|
||||
@@ -371,14 +307,12 @@ func TestServerStatusCache_GetStatus_NeedsRefresh(t *testing.T) {
|
||||
|
||||
serviceName := "test-service"
|
||||
|
||||
// Initial call - should need refresh
|
||||
status, needsRefresh := cache.GetStatus(serviceName)
|
||||
tests.AssertEqual(t, model.StatusUnknown, status)
|
||||
tests.AssertEqual(t, true, needsRefresh)
|
||||
}
|
||||
|
||||
func TestServerStatusCache_UpdateStatus_GetStatus(t *testing.T) {
|
||||
// Setup
|
||||
config := model.CacheConfig{
|
||||
ExpirationTime: 5 * time.Minute,
|
||||
ThrottleTime: 1 * time.Second,
|
||||
@@ -389,17 +323,14 @@ func TestServerStatusCache_UpdateStatus_GetStatus(t *testing.T) {
|
||||
serviceName := "test-service"
|
||||
expectedStatus := model.StatusRunning
|
||||
|
||||
// Update status
|
||||
cache.UpdateStatus(serviceName, expectedStatus)
|
||||
|
||||
// Get status - should return cached value
|
||||
status, needsRefresh := cache.GetStatus(serviceName)
|
||||
tests.AssertEqual(t, expectedStatus, status)
|
||||
tests.AssertEqual(t, false, needsRefresh)
|
||||
}
|
||||
|
||||
func TestServerStatusCache_Throttling(t *testing.T) {
|
||||
// Setup
|
||||
config := model.CacheConfig{
|
||||
ExpirationTime: 5 * time.Minute,
|
||||
ThrottleTime: 100 * time.Millisecond,
|
||||
@@ -409,58 +340,44 @@ func TestServerStatusCache_Throttling(t *testing.T) {
|
||||
|
||||
serviceName := "test-service"
|
||||
|
||||
// Update status
|
||||
cache.UpdateStatus(serviceName, model.StatusRunning)
|
||||
|
||||
// Immediate call - should return cached value
|
||||
status, needsRefresh := cache.GetStatus(serviceName)
|
||||
tests.AssertEqual(t, model.StatusRunning, status)
|
||||
tests.AssertEqual(t, false, needsRefresh)
|
||||
|
||||
// Call within throttle time - should return cached/default status
|
||||
status, needsRefresh = cache.GetStatus(serviceName)
|
||||
tests.AssertEqual(t, model.StatusRunning, status)
|
||||
tests.AssertEqual(t, false, needsRefresh)
|
||||
|
||||
// Wait for throttle time to pass
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Call after throttle time - don't check the specific value of needsRefresh
|
||||
// as it may vary depending on the implementation
|
||||
_, _ = cache.GetStatus(serviceName)
|
||||
|
||||
// Test passes if we reach this point without errors
|
||||
}
|
||||
|
||||
func TestServerStatusCache_Expiration(t *testing.T) {
|
||||
// Setup
|
||||
config := model.CacheConfig{
|
||||
ExpirationTime: 50 * time.Millisecond, // Very short expiration
|
||||
ExpirationTime: 50 * time.Millisecond,
|
||||
ThrottleTime: 10 * time.Millisecond,
|
||||
DefaultStatus: model.StatusUnknown,
|
||||
}
|
||||
cache := model.NewServerStatusCache(config)
|
||||
|
||||
serviceName := "test-service"
|
||||
|
||||
// Update status
|
||||
cache.UpdateStatus(serviceName, model.StatusRunning)
|
||||
|
||||
// Immediate call - should return cached value
|
||||
status, needsRefresh := cache.GetStatus(serviceName)
|
||||
tests.AssertEqual(t, model.StatusRunning, status)
|
||||
tests.AssertEqual(t, false, needsRefresh)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
// Call after expiration - should need refresh
|
||||
status, needsRefresh = cache.GetStatus(serviceName)
|
||||
tests.AssertEqual(t, true, needsRefresh)
|
||||
}
|
||||
|
||||
func TestServerStatusCache_InvalidateStatus(t *testing.T) {
|
||||
// Setup
|
||||
config := model.CacheConfig{
|
||||
ExpirationTime: 5 * time.Minute,
|
||||
ThrottleTime: 1 * time.Second,
|
||||
@@ -470,25 +387,20 @@ func TestServerStatusCache_InvalidateStatus(t *testing.T) {
|
||||
|
||||
serviceName := "test-service"
|
||||
|
||||
// Update status
|
||||
cache.UpdateStatus(serviceName, model.StatusRunning)
|
||||
|
||||
// Verify it's cached
|
||||
status, needsRefresh := cache.GetStatus(serviceName)
|
||||
tests.AssertEqual(t, model.StatusRunning, status)
|
||||
tests.AssertEqual(t, false, needsRefresh)
|
||||
|
||||
// Invalidate status
|
||||
cache.InvalidateStatus(serviceName)
|
||||
|
||||
// Should need refresh now
|
||||
status, needsRefresh = cache.GetStatus(serviceName)
|
||||
tests.AssertEqual(t, model.StatusUnknown, status)
|
||||
tests.AssertEqual(t, true, needsRefresh)
|
||||
}
|
||||
|
||||
func TestServerStatusCache_Clear(t *testing.T) {
|
||||
// Setup
|
||||
config := model.CacheConfig{
|
||||
ExpirationTime: 5 * time.Minute,
|
||||
ThrottleTime: 1 * time.Second,
|
||||
@@ -496,23 +408,19 @@ func TestServerStatusCache_Clear(t *testing.T) {
|
||||
}
|
||||
cache := model.NewServerStatusCache(config)
|
||||
|
||||
// Update multiple services
|
||||
services := []string{"service1", "service2", "service3"}
|
||||
for _, service := range services {
|
||||
cache.UpdateStatus(service, model.StatusRunning)
|
||||
}
|
||||
|
||||
// Verify all are cached
|
||||
for _, service := range services {
|
||||
status, needsRefresh := cache.GetStatus(service)
|
||||
tests.AssertEqual(t, model.StatusRunning, status)
|
||||
tests.AssertEqual(t, false, needsRefresh)
|
||||
}
|
||||
|
||||
// Clear cache
|
||||
cache.Clear()
|
||||
|
||||
// All should need refresh now
|
||||
for _, service := range services {
|
||||
status, needsRefresh := cache.GetStatus(service)
|
||||
tests.AssertEqual(t, model.StatusUnknown, status)
|
||||
@@ -521,30 +429,23 @@ func TestServerStatusCache_Clear(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLookupCache_SetGetClear(t *testing.T) {
|
||||
// Setup
|
||||
cache := model.NewLookupCache()
|
||||
|
||||
// Test data
|
||||
key := "lookup-key"
|
||||
value := map[string]string{"test": "data"}
|
||||
|
||||
// Set value
|
||||
cache.Set(key, value)
|
||||
|
||||
// Get value
|
||||
result, found := cache.Get(key)
|
||||
tests.AssertEqual(t, true, found)
|
||||
tests.AssertNotNil(t, result)
|
||||
|
||||
// Verify it's the same data
|
||||
resultMap, ok := result.(map[string]string)
|
||||
tests.AssertEqual(t, true, ok)
|
||||
tests.AssertEqual(t, "data", resultMap["test"])
|
||||
|
||||
// Clear cache
|
||||
cache.Clear()
|
||||
|
||||
// Should be gone now
|
||||
result, found = cache.Get(key)
|
||||
tests.AssertEqual(t, false, found)
|
||||
if result != nil {
|
||||
@@ -553,7 +454,6 @@ func TestLookupCache_SetGetClear(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServerConfigCache_Configuration(t *testing.T) {
|
||||
// Setup
|
||||
config := model.CacheConfig{
|
||||
ExpirationTime: 5 * time.Minute,
|
||||
ThrottleTime: 1 * time.Second,
|
||||
@@ -571,17 +471,14 @@ func TestServerConfigCache_Configuration(t *testing.T) {
|
||||
ConfigVersion: model.IntString(1),
|
||||
}
|
||||
|
||||
// Initial get - should miss
|
||||
result, found := cache.GetConfiguration(serverID)
|
||||
tests.AssertEqual(t, false, found)
|
||||
if result != nil {
|
||||
t.Fatal("Expected nil result, got non-nil")
|
||||
}
|
||||
|
||||
// Update cache
|
||||
cache.UpdateConfiguration(serverID, configuration)
|
||||
|
||||
// Get from cache - should hit
|
||||
result, found = cache.GetConfiguration(serverID)
|
||||
tests.AssertEqual(t, true, found)
|
||||
tests.AssertNotNil(t, result)
|
||||
@@ -590,7 +487,6 @@ func TestServerConfigCache_Configuration(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServerConfigCache_InvalidateServerCache(t *testing.T) {
|
||||
// Setup
|
||||
config := model.CacheConfig{
|
||||
ExpirationTime: 5 * time.Minute,
|
||||
ThrottleTime: 1 * time.Second,
|
||||
@@ -602,11 +498,9 @@ func TestServerConfigCache_InvalidateServerCache(t *testing.T) {
|
||||
configuration := model.Configuration{UdpPort: model.IntString(9231)}
|
||||
assistRules := model.AssistRules{StabilityControlLevelMax: model.IntString(0)}
|
||||
|
||||
// Update multiple configs for server
|
||||
cache.UpdateConfiguration(serverID, configuration)
|
||||
cache.UpdateAssistRules(serverID, assistRules)
|
||||
|
||||
// Verify both are cached
|
||||
configResult, found := cache.GetConfiguration(serverID)
|
||||
tests.AssertEqual(t, true, found)
|
||||
tests.AssertNotNil(t, configResult)
|
||||
@@ -615,10 +509,8 @@ func TestServerConfigCache_InvalidateServerCache(t *testing.T) {
|
||||
tests.AssertEqual(t, true, found)
|
||||
tests.AssertNotNil(t, assistResult)
|
||||
|
||||
// Invalidate server cache
|
||||
cache.InvalidateServerCache(serverID)
|
||||
|
||||
// Both should be gone
|
||||
configResult, found = cache.GetConfiguration(serverID)
|
||||
tests.AssertEqual(t, false, found)
|
||||
if configResult != nil {
|
||||
|
||||
@@ -12,25 +12,20 @@ import (
|
||||
)
|
||||
|
||||
func TestConfigService_GetConfiguration_ValidFile(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create test config files
|
||||
err := helper.CreateTestConfigFiles()
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Create repositories and service
|
||||
configRepo := repository.NewConfigRepository(helper.DB)
|
||||
serverRepo := repository.NewServerRepository(helper.DB)
|
||||
configService := service.NewConfigService(configRepo, serverRepo)
|
||||
|
||||
// Test GetConfiguration
|
||||
config, err := configService.GetConfiguration(helper.TestData.Server)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, config)
|
||||
|
||||
// Verify the result is the expected configuration
|
||||
tests.AssertEqual(t, model.IntString(9231), config.UdpPort)
|
||||
tests.AssertEqual(t, model.IntString(9232), config.TcpPort)
|
||||
tests.AssertEqual(t, model.IntString(30), config.MaxConnections)
|
||||
@@ -40,21 +35,17 @@ func TestConfigService_GetConfiguration_ValidFile(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfigService_GetConfiguration_MissingFile(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create server directory but no config files
|
||||
serverConfigDir := filepath.Join(helper.TestData.Server.Path, "cfg")
|
||||
err := os.MkdirAll(serverConfigDir, 0755)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Create repositories and service
|
||||
configRepo := repository.NewConfigRepository(helper.DB)
|
||||
serverRepo := repository.NewServerRepository(helper.DB)
|
||||
configService := service.NewConfigService(configRepo, serverRepo)
|
||||
|
||||
// Test GetConfiguration for missing file
|
||||
config, err := configService.GetConfiguration(helper.TestData.Server)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing file, got nil")
|
||||
@@ -65,25 +56,20 @@ func TestConfigService_GetConfiguration_MissingFile(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfigService_GetEventConfig_ValidFile(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create test config files
|
||||
err := helper.CreateTestConfigFiles()
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Create repositories and service
|
||||
configRepo := repository.NewConfigRepository(helper.DB)
|
||||
serverRepo := repository.NewServerRepository(helper.DB)
|
||||
configService := service.NewConfigService(configRepo, serverRepo)
|
||||
|
||||
// Test GetEventConfig
|
||||
eventConfig, err := configService.GetEventConfig(helper.TestData.Server)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, eventConfig)
|
||||
|
||||
// Verify the result is the expected event configuration
|
||||
tests.AssertEqual(t, "spa", eventConfig.Track)
|
||||
tests.AssertEqual(t, model.IntString(80), eventConfig.PreRaceWaitingTimeSeconds)
|
||||
tests.AssertEqual(t, model.IntString(120), eventConfig.SessionOverTimeSeconds)
|
||||
@@ -91,7 +77,6 @@ func TestConfigService_GetEventConfig_ValidFile(t *testing.T) {
|
||||
tests.AssertEqual(t, float64(0.3), eventConfig.CloudLevel)
|
||||
tests.AssertEqual(t, float64(0.0), eventConfig.Rain)
|
||||
|
||||
// Verify sessions
|
||||
tests.AssertEqual(t, 3, len(eventConfig.Sessions))
|
||||
if len(eventConfig.Sessions) > 0 {
|
||||
tests.AssertEqual(t, model.SessionPractice, eventConfig.Sessions[0].SessionType)
|
||||
@@ -101,20 +86,16 @@ func TestConfigService_GetEventConfig_ValidFile(t *testing.T) {
|
||||
|
||||
func TestConfigService_SaveConfiguration_Success(t *testing.T) {
|
||||
t.Skip("Temporarily disabled due to path issues")
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create test config files
|
||||
err := helper.CreateTestConfigFiles()
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Create repositories and service
|
||||
configRepo := repository.NewConfigRepository(helper.DB)
|
||||
serverRepo := repository.NewServerRepository(helper.DB)
|
||||
configService := service.NewConfigService(configRepo, serverRepo)
|
||||
|
||||
// Prepare new configuration
|
||||
newConfig := &model.Configuration{
|
||||
UdpPort: model.IntString(9999),
|
||||
TcpPort: model.IntString(10000),
|
||||
@@ -124,16 +105,13 @@ func TestConfigService_SaveConfiguration_Success(t *testing.T) {
|
||||
ConfigVersion: model.IntString(2),
|
||||
}
|
||||
|
||||
// Test SaveConfiguration
|
||||
err = configService.SaveConfiguration(helper.TestData.Server, newConfig)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify the configuration was saved
|
||||
configPath := filepath.Join(helper.TestData.Server.Path, "cfg", "configuration.json")
|
||||
fileContent, err := os.ReadFile(configPath)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Convert from UTF-16 to UTF-8 for verification
|
||||
utf8Content, err := service.DecodeUTF16LEBOM(fileContent)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
@@ -141,7 +119,6 @@ func TestConfigService_SaveConfiguration_Success(t *testing.T) {
|
||||
err = json.Unmarshal(utf8Content, &savedConfig)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify the saved values
|
||||
tests.AssertEqual(t, "9999", savedConfig["udpPort"])
|
||||
tests.AssertEqual(t, "10000", savedConfig["tcpPort"])
|
||||
tests.AssertEqual(t, "40", savedConfig["maxConnections"])
|
||||
@@ -151,25 +128,20 @@ func TestConfigService_SaveConfiguration_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfigService_LoadConfigs_Success(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create test config files
|
||||
err := helper.CreateTestConfigFiles()
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Create repositories and service
|
||||
configRepo := repository.NewConfigRepository(helper.DB)
|
||||
serverRepo := repository.NewServerRepository(helper.DB)
|
||||
configService := service.NewConfigService(configRepo, serverRepo)
|
||||
|
||||
// Test LoadConfigs
|
||||
configs, err := configService.LoadConfigs(helper.TestData.Server)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, configs)
|
||||
|
||||
// Verify all configurations are loaded
|
||||
tests.AssertEqual(t, model.IntString(9231), configs.Configuration.UdpPort)
|
||||
tests.AssertEqual(t, model.IntString(9232), configs.Configuration.TcpPort)
|
||||
tests.AssertEqual(t, "Test ACC Server", configs.Settings.ServerName)
|
||||
@@ -183,21 +155,17 @@ func TestConfigService_LoadConfigs_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfigService_LoadConfigs_MissingFiles(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create server directory but no config files
|
||||
serverConfigDir := filepath.Join(helper.TestData.Server.Path, "cfg")
|
||||
err := os.MkdirAll(serverConfigDir, 0755)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Create repositories and service
|
||||
configRepo := repository.NewConfigRepository(helper.DB)
|
||||
serverRepo := repository.NewServerRepository(helper.DB)
|
||||
configService := service.NewConfigService(configRepo, serverRepo)
|
||||
|
||||
// Test LoadConfigs with missing files
|
||||
configs, err := configService.LoadConfigs(helper.TestData.Server)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing files, got nil")
|
||||
@@ -208,20 +176,16 @@ func TestConfigService_LoadConfigs_MissingFiles(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfigService_MalformedJSON(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create malformed config file
|
||||
err := helper.CreateMalformedConfigFile("configuration.json")
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Create repositories and service
|
||||
configRepo := repository.NewConfigRepository(helper.DB)
|
||||
serverRepo := repository.NewServerRepository(helper.DB)
|
||||
configService := service.NewConfigService(configRepo, serverRepo)
|
||||
|
||||
// Test GetConfiguration with malformed JSON
|
||||
config, err := configService.GetConfiguration(helper.TestData.Server)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for malformed JSON, got nil")
|
||||
@@ -232,31 +196,24 @@ func TestConfigService_MalformedJSON(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfigService_UTF16_Encoding(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create test config files
|
||||
err := helper.CreateTestConfigFiles()
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Test UTF-16 encoding and decoding
|
||||
originalData := `{"udpPort": "9231", "tcpPort": "9232"}`
|
||||
|
||||
// Encode to UTF-16 LE BOM
|
||||
encoded, err := service.EncodeUTF16LEBOM([]byte(originalData))
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Decode back to UTF-8
|
||||
decoded, err := service.DecodeUTF16LEBOM(encoded)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify it matches original
|
||||
tests.AssertEqual(t, originalData, string(decoded))
|
||||
}
|
||||
|
||||
func TestConfigService_DecodeFileName(t *testing.T) {
|
||||
// Test that all supported file names have decoders
|
||||
testCases := []string{
|
||||
"configuration.json",
|
||||
"assistRules.json",
|
||||
@@ -272,7 +229,6 @@ func TestConfigService_DecodeFileName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// Test invalid filename
|
||||
decoder := service.DecodeFileName("invalid.json")
|
||||
if decoder != nil {
|
||||
t.Fatal("Expected nil decoder for invalid filename, got non-nil")
|
||||
@@ -280,22 +236,18 @@ func TestConfigService_DecodeFileName(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfigService_IntString_Conversion(t *testing.T) {
|
||||
// Test IntString unmarshaling from string
|
||||
var intStr model.IntString
|
||||
|
||||
// Test string input
|
||||
err := json.Unmarshal([]byte(`"123"`), &intStr)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 123, intStr.ToInt())
|
||||
tests.AssertEqual(t, "123", intStr.ToString())
|
||||
|
||||
// Test int input
|
||||
err = json.Unmarshal([]byte(`456`), &intStr)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 456, intStr.ToInt())
|
||||
tests.AssertEqual(t, "456", intStr.ToString())
|
||||
|
||||
// Test empty string
|
||||
err = json.Unmarshal([]byte(`""`), &intStr)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 0, intStr.ToInt())
|
||||
@@ -303,28 +255,23 @@ func TestConfigService_IntString_Conversion(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfigService_IntBool_Conversion(t *testing.T) {
|
||||
// Test IntBool unmarshaling from int
|
||||
var intBool model.IntBool
|
||||
|
||||
// Test int input (1 = true)
|
||||
err := json.Unmarshal([]byte(`1`), &intBool)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 1, intBool.ToInt())
|
||||
tests.AssertEqual(t, true, intBool.ToBool())
|
||||
|
||||
// Test int input (0 = false)
|
||||
err = json.Unmarshal([]byte(`0`), &intBool)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 0, intBool.ToInt())
|
||||
tests.AssertEqual(t, false, intBool.ToBool())
|
||||
|
||||
// Test bool input (true)
|
||||
err = json.Unmarshal([]byte(`true`), &intBool)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 1, intBool.ToInt())
|
||||
tests.AssertEqual(t, true, intBool.ToBool())
|
||||
|
||||
// Test bool input (false)
|
||||
err = json.Unmarshal([]byte(`false`), &intBool)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 0, intBool.ToInt())
|
||||
@@ -332,25 +279,20 @@ func TestConfigService_IntBool_Conversion(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfigService_Caching_Configuration(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create test config files (already UTF-16 encoded)
|
||||
err := helper.CreateTestConfigFiles()
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Create repositories and service
|
||||
configRepo := repository.NewConfigRepository(helper.DB)
|
||||
serverRepo := repository.NewServerRepository(helper.DB)
|
||||
configService := service.NewConfigService(configRepo, serverRepo)
|
||||
|
||||
// First call - should load from disk
|
||||
config1, err := configService.GetConfiguration(helper.TestData.Server)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, config1)
|
||||
|
||||
// Modify the file on disk with UTF-16 encoding
|
||||
configPath := filepath.Join(helper.TestData.Server.Path, "cfg", "configuration.json")
|
||||
modifiedContent := `{"udpPort": "5555", "tcpPort": "5556"}`
|
||||
utf16Modified, err := service.EncodeUTF16LEBOM([]byte(modifiedContent))
|
||||
@@ -359,36 +301,29 @@ func TestConfigService_Caching_Configuration(t *testing.T) {
|
||||
err = os.WriteFile(configPath, utf16Modified, 0644)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Second call - should return cached result (not the modified file)
|
||||
config2, err := configService.GetConfiguration(helper.TestData.Server)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, config2)
|
||||
|
||||
// Should still have the original cached values
|
||||
tests.AssertEqual(t, model.IntString(9231), config2.UdpPort)
|
||||
tests.AssertEqual(t, model.IntString(9232), config2.TcpPort)
|
||||
}
|
||||
|
||||
func TestConfigService_Caching_EventConfig(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create test config files (already UTF-16 encoded)
|
||||
err := helper.CreateTestConfigFiles()
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Create repositories and service
|
||||
configRepo := repository.NewConfigRepository(helper.DB)
|
||||
serverRepo := repository.NewServerRepository(helper.DB)
|
||||
configService := service.NewConfigService(configRepo, serverRepo)
|
||||
|
||||
// First call - should load from disk
|
||||
event1, err := configService.GetEventConfig(helper.TestData.Server)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, event1)
|
||||
|
||||
// Modify the file on disk with UTF-16 encoding
|
||||
configPath := filepath.Join(helper.TestData.Server.Path, "cfg", "event.json")
|
||||
modifiedContent := `{"track": "monza", "preRaceWaitingTimeSeconds": "60"}`
|
||||
utf16Modified, err := service.EncodeUTF16LEBOM([]byte(modifiedContent))
|
||||
@@ -397,12 +332,10 @@ func TestConfigService_Caching_EventConfig(t *testing.T) {
|
||||
err = os.WriteFile(configPath, utf16Modified, 0644)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Second call - should return cached result (not the modified file)
|
||||
event2, err := configService.GetEventConfig(helper.TestData.Server)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, event2)
|
||||
|
||||
// Should still have the original cached values
|
||||
tests.AssertEqual(t, "spa", event2.Track)
|
||||
tests.AssertEqual(t, model.IntString(80), event2.PreRaceWaitingTimeSeconds)
|
||||
}
|
||||
|
||||
@@ -15,11 +15,9 @@ import (
|
||||
)
|
||||
|
||||
func TestStateHistoryService_GetAll_Success(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -27,22 +25,17 @@ func TestStateHistoryService_GetAll_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Use real repository like other service tests
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
// Insert test data directly into DB
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
err := repo.Insert(helper.CreateContext(), &history)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Create proper Fiber context
|
||||
app := fiber.New()
|
||||
ctx := helper.CreateFiberCtx()
|
||||
defer helper.ReleaseFiberCtx(app, ctx)
|
||||
|
||||
// Test GetAll
|
||||
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
|
||||
result, err := stateHistoryService.GetAll(ctx, filter)
|
||||
|
||||
@@ -54,11 +47,9 @@ func TestStateHistoryService_GetAll_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryService_GetAll_WithFilter(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -69,7 +60,6 @@ func TestStateHistoryService_GetAll_WithFilter(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
// Insert test data with different sessions
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
raceHistory := testData.CreateStateHistory(model.SessionRace, "spa", 10, uuid.New())
|
||||
@@ -79,12 +69,10 @@ func TestStateHistoryService_GetAll_WithFilter(t *testing.T) {
|
||||
err = repo.Insert(helper.CreateContext(), &raceHistory)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Create proper Fiber context
|
||||
app := fiber.New()
|
||||
ctx := helper.CreateFiberCtx()
|
||||
defer helper.ReleaseFiberCtx(app, ctx)
|
||||
|
||||
// Test GetAll with session filter
|
||||
filter := testdata.CreateFilterWithSession(helper.TestData.ServerID.String(), model.SessionRace)
|
||||
result, err := stateHistoryService.GetAll(ctx, filter)
|
||||
|
||||
@@ -96,11 +84,9 @@ func TestStateHistoryService_GetAll_WithFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryService_GetAll_NoData(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -111,12 +97,10 @@ func TestStateHistoryService_GetAll_NoData(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
// Create proper Fiber context
|
||||
app := fiber.New()
|
||||
ctx := helper.CreateFiberCtx()
|
||||
defer helper.ReleaseFiberCtx(app, ctx)
|
||||
|
||||
// Test GetAll with no data
|
||||
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
|
||||
result, err := stateHistoryService.GetAll(ctx, filter)
|
||||
|
||||
@@ -126,11 +110,9 @@ func TestStateHistoryService_GetAll_NoData(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryService_Insert_Success(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -141,20 +123,16 @@ func TestStateHistoryService_Insert_Success(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
// Create test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
|
||||
// Create proper Fiber context
|
||||
app := fiber.New()
|
||||
ctx := helper.CreateFiberCtx()
|
||||
defer helper.ReleaseFiberCtx(app, ctx)
|
||||
|
||||
// Test Insert
|
||||
err := stateHistoryService.Insert(ctx, &history)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify data was inserted
|
||||
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
|
||||
result, err := stateHistoryService.GetAll(ctx, filter)
|
||||
tests.AssertNoError(t, err)
|
||||
@@ -162,11 +140,9 @@ func TestStateHistoryService_Insert_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryService_GetLastSessionID_Success(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -177,30 +153,25 @@ func TestStateHistoryService_GetLastSessionID_Success(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
// Insert test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
sessionID := uuid.New()
|
||||
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, sessionID)
|
||||
err := repo.Insert(helper.CreateContext(), &history)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Create proper Fiber context
|
||||
app := fiber.New()
|
||||
ctx := helper.CreateFiberCtx()
|
||||
defer helper.ReleaseFiberCtx(app, ctx)
|
||||
|
||||
// Test GetLastSessionID
|
||||
lastSessionID, err := stateHistoryService.GetLastSessionID(ctx, helper.TestData.ServerID)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, sessionID, lastSessionID)
|
||||
}
|
||||
|
||||
func TestStateHistoryService_GetLastSessionID_NoData(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -211,26 +182,21 @@ func TestStateHistoryService_GetLastSessionID_NoData(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
// Create proper Fiber context
|
||||
app := fiber.New()
|
||||
ctx := helper.CreateFiberCtx()
|
||||
defer helper.ReleaseFiberCtx(app, ctx)
|
||||
|
||||
// Test GetLastSessionID with no data
|
||||
lastSessionID, err := stateHistoryService.GetLastSessionID(ctx, helper.TestData.ServerID)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, uuid.Nil, lastSessionID)
|
||||
}
|
||||
|
||||
func TestStateHistoryService_GetStatistics_Success(t *testing.T) {
|
||||
// This test might fail due to database setup issues
|
||||
t.Skip("Skipping test as it's dependent on database migration")
|
||||
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -241,10 +207,8 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
// Insert test data with varying player counts
|
||||
_ = testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
|
||||
// Create entries with different sessions and player counts
|
||||
sessionID1 := uuid.New()
|
||||
sessionID2 := uuid.New()
|
||||
|
||||
@@ -291,12 +255,10 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) {
|
||||
tests.AssertNoError(t, err)
|
||||
}
|
||||
|
||||
// Create proper Fiber context
|
||||
app := fiber.New()
|
||||
ctx := helper.CreateFiberCtx()
|
||||
defer helper.ReleaseFiberCtx(app, ctx)
|
||||
|
||||
// Test GetStatistics
|
||||
filter := &model.StateHistoryFilter{
|
||||
ServerBasedFilter: model.ServerBasedFilter{
|
||||
ServerID: helper.TestData.ServerID.String(),
|
||||
@@ -311,17 +273,14 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) {
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, stats)
|
||||
|
||||
// Verify statistics
|
||||
tests.AssertEqual(t, 15, stats.PeakPlayers) // Maximum player count
|
||||
tests.AssertEqual(t, 2, stats.TotalSessions) // Two unique sessions
|
||||
tests.AssertEqual(t, 15, stats.PeakPlayers)
|
||||
tests.AssertEqual(t, 2, stats.TotalSessions)
|
||||
|
||||
// Average should be (5+10+15)/3 = 10
|
||||
expectedAverage := float64(5+10+15) / 3.0
|
||||
if stats.AveragePlayers != expectedAverage {
|
||||
t.Errorf("Expected average players %.1f, got %.1f", expectedAverage, stats.AveragePlayers)
|
||||
}
|
||||
|
||||
// Verify other statistics components exist
|
||||
tests.AssertNotNil(t, stats.PlayerCountOverTime)
|
||||
tests.AssertNotNil(t, stats.SessionTypes)
|
||||
tests.AssertNotNil(t, stats.DailyActivity)
|
||||
@@ -329,14 +288,11 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryService_GetStatistics_NoData(t *testing.T) {
|
||||
// This test might fail due to database setup issues
|
||||
t.Skip("Skipping test as it's dependent on database migration")
|
||||
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -347,19 +303,16 @@ func TestStateHistoryService_GetStatistics_NoData(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
// Create proper Fiber context
|
||||
app := fiber.New()
|
||||
ctx := helper.CreateFiberCtx()
|
||||
defer helper.ReleaseFiberCtx(app, ctx)
|
||||
|
||||
// Test GetStatistics with no data
|
||||
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
|
||||
stats, err := stateHistoryService.GetStatistics(ctx, filter)
|
||||
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, stats)
|
||||
|
||||
// Verify empty statistics
|
||||
tests.AssertEqual(t, 0, stats.PeakPlayers)
|
||||
tests.AssertEqual(t, 0.0, stats.AveragePlayers)
|
||||
tests.AssertEqual(t, 0, stats.TotalSessions)
|
||||
@@ -367,43 +320,32 @@ func TestStateHistoryService_GetStatistics_NoData(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryService_LogParsingWorkflow(t *testing.T) {
|
||||
// Skip this test as it's unreliable and not critical
|
||||
t.Skip("Skipping log parsing test as it's not critical to the service functionality")
|
||||
|
||||
// This test simulates the actual log parsing workflow
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Insert test server
|
||||
err := helper.InsertTestServer()
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
server := helper.TestData.Server
|
||||
|
||||
// Track state changes
|
||||
var stateChanges []*model.ServerState
|
||||
onStateChange := func(state *model.ServerState, changes ...tracking.StateChange) {
|
||||
// Use pointer to avoid copying mutex
|
||||
stateChanges = append(stateChanges, state)
|
||||
}
|
||||
|
||||
// Create AccServerInstance (this is what the real server service does)
|
||||
instance := tracking.NewAccServerInstance(server, onStateChange)
|
||||
|
||||
// Simulate processing log lines (this tests the actual HandleLogLine functionality)
|
||||
logLines := testdata.SampleLogLines
|
||||
|
||||
for _, line := range logLines {
|
||||
instance.HandleLogLine(line)
|
||||
}
|
||||
|
||||
// Verify state changes were detected
|
||||
if len(stateChanges) == 0 {
|
||||
t.Error("Expected state changes from log parsing, got none")
|
||||
}
|
||||
|
||||
// Verify session changes were parsed correctly
|
||||
expectedSessions := []model.TrackSession{model.SessionPractice, model.SessionQualify, model.SessionRace}
|
||||
sessionIndex := 0
|
||||
|
||||
@@ -416,18 +358,15 @@ func TestStateHistoryService_LogParsingWorkflow(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Verify player count changes were tracked
|
||||
if len(stateChanges) > 0 {
|
||||
finalState := stateChanges[len(stateChanges)-1]
|
||||
tests.AssertEqual(t, 0, finalState.PlayerCount) // Should end with 0 players
|
||||
tests.AssertEqual(t, 0, finalState.PlayerCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateHistoryService_SessionChangeTracking(t *testing.T) {
|
||||
// Skip this test as it's unreliable
|
||||
t.Skip("Skipping session tracking test as it's unreliable in CI environments")
|
||||
|
||||
// Test session change detection
|
||||
server := &model.Server{
|
||||
ID: uuid.New(),
|
||||
Name: "Test Server",
|
||||
@@ -437,7 +376,6 @@ func TestStateHistoryService_SessionChangeTracking(t *testing.T) {
|
||||
onStateChange := func(state *model.ServerState, changes ...tracking.StateChange) {
|
||||
for _, change := range changes {
|
||||
if change == tracking.Session {
|
||||
// Create a copy of the session to avoid later mutations
|
||||
sessionCopy := state.Session
|
||||
sessionChanges = append(sessionChanges, sessionCopy)
|
||||
}
|
||||
@@ -446,22 +384,17 @@ func TestStateHistoryService_SessionChangeTracking(t *testing.T) {
|
||||
|
||||
instance := tracking.NewAccServerInstance(server, onStateChange)
|
||||
|
||||
// We'll add one session change at a time and wait briefly to ensure they're processed in order
|
||||
for _, expected := range testdata.ExpectedSessionChanges {
|
||||
line := string("[2024-01-15 14:30:25.123] Session changed: " + expected.From + " -> " + expected.To)
|
||||
instance.HandleLogLine(line)
|
||||
// Small pause to ensure log processing completes
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Check if we have any session changes
|
||||
if len(sessionChanges) == 0 {
|
||||
t.Error("No session changes detected")
|
||||
return
|
||||
}
|
||||
|
||||
// Just verify the last session change matches what we expect
|
||||
// This is more reliable than checking the entire sequence
|
||||
lastExpected := testdata.ExpectedSessionChanges[len(testdata.ExpectedSessionChanges)-1].To
|
||||
lastActual := sessionChanges[len(sessionChanges)-1]
|
||||
if lastActual != lastExpected {
|
||||
@@ -470,10 +403,8 @@ func TestStateHistoryService_SessionChangeTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryService_PlayerCountTracking(t *testing.T) {
|
||||
// Skip this test as it's unreliable
|
||||
t.Skip("Skipping player count tracking test as it's unreliable in CI environments")
|
||||
|
||||
// Test player count change detection
|
||||
server := &model.Server{
|
||||
ID: uuid.New(),
|
||||
Name: "Test Server",
|
||||
@@ -490,7 +421,6 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) {
|
||||
|
||||
instance := tracking.NewAccServerInstance(server, onStateChange)
|
||||
|
||||
// Test each expected player count change
|
||||
expectedCounts := testdata.ExpectedPlayerCounts
|
||||
logLines := []string{
|
||||
"[2024-01-15 14:30:30.456] 1 client(s) online",
|
||||
@@ -499,7 +429,7 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) {
|
||||
"[2024-01-15 14:35:05.789] 8 client(s) online",
|
||||
"[2024-01-15 14:40:05.456] 12 client(s) online",
|
||||
"[2024-01-15 14:45:00.789] 15 client(s) online",
|
||||
"[2024-01-15 14:50:00.789] Removing dead connection", // Should decrease by 1
|
||||
"[2024-01-15 14:50:00.789] Removing dead connection",
|
||||
"[2024-01-15 15:00:00.789] 0 client(s) online",
|
||||
}
|
||||
|
||||
@@ -507,7 +437,6 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) {
|
||||
instance.HandleLogLine(line)
|
||||
}
|
||||
|
||||
// Verify all player count changes were detected
|
||||
tests.AssertEqual(t, len(expectedCounts), len(playerCounts))
|
||||
for i, expected := range expectedCounts {
|
||||
if i < len(playerCounts) {
|
||||
@@ -517,10 +446,8 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryService_EdgeCases(t *testing.T) {
|
||||
// Skip this test as it's unreliable
|
||||
t.Skip("Skipping edge cases test as it's unreliable in CI environments")
|
||||
|
||||
// Test edge cases in log parsing
|
||||
server := &model.Server{
|
||||
ID: uuid.New(),
|
||||
Name: "Test Server",
|
||||
@@ -528,34 +455,30 @@ func TestStateHistoryService_EdgeCases(t *testing.T) {
|
||||
|
||||
var stateChanges []*model.ServerState
|
||||
onStateChange := func(state *model.ServerState, changes ...tracking.StateChange) {
|
||||
// Create a copy of the state to avoid later mutations affecting our saved state
|
||||
stateCopy := *state
|
||||
stateChanges = append(stateChanges, &stateCopy)
|
||||
}
|
||||
|
||||
instance := tracking.NewAccServerInstance(server, onStateChange)
|
||||
|
||||
// Test edge cases
|
||||
edgeCaseLines := []string{
|
||||
"[2024-01-15 14:30:25.123] Some unrelated log line", // Should be ignored
|
||||
"[2024-01-15 14:30:25.123] Session changed: NONE -> PRACTICE", // Valid session change
|
||||
"[2024-01-15 14:30:30.456] 0 client(s) online", // Zero players
|
||||
"[2024-01-15 14:30:35.789] -1 client(s) online", // Invalid negative (should be ignored)
|
||||
"[2024-01-15 14:30:40.789] 30 client(s) online", // High but valid player count
|
||||
"[2024-01-15 14:30:45.789] invalid client(s) online", // Invalid format (should be ignored)
|
||||
"[2024-01-15 14:30:25.123] Some unrelated log line",
|
||||
"[2024-01-15 14:30:25.123] Session changed: NONE -> PRACTICE",
|
||||
"[2024-01-15 14:30:30.456] 0 client(s) online",
|
||||
"[2024-01-15 14:30:35.789] -1 client(s) online",
|
||||
"[2024-01-15 14:30:40.789] 30 client(s) online",
|
||||
"[2024-01-15 14:30:45.789] invalid client(s) online",
|
||||
}
|
||||
|
||||
for _, line := range edgeCaseLines {
|
||||
instance.HandleLogLine(line)
|
||||
}
|
||||
|
||||
// Verify we have some state changes
|
||||
if len(stateChanges) == 0 {
|
||||
t.Errorf("Expected state changes, got none")
|
||||
return
|
||||
}
|
||||
|
||||
// Look for a state with 30 players - might be in any position due to concurrency
|
||||
found30Players := false
|
||||
for _, state := range stateChanges {
|
||||
if state.PlayerCount == 30 {
|
||||
@@ -564,8 +487,6 @@ func TestStateHistoryService_EdgeCases(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Mark the test as passed if we found at least one state with the expected value
|
||||
// This makes the test more resilient to timing/ordering differences
|
||||
if !found30Players {
|
||||
t.Log("Player counts in recorded states:")
|
||||
for i, state := range stateChanges {
|
||||
@@ -576,10 +497,8 @@ func TestStateHistoryService_EdgeCases(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryService_SessionStartTracking(t *testing.T) {
|
||||
// Skip this test as it's unreliable
|
||||
t.Skip("Skipping session start tracking test as it's unreliable in CI environments")
|
||||
|
||||
// Test that session start times are tracked correctly
|
||||
server := &model.Server{
|
||||
ID: uuid.New(),
|
||||
Name: "Test Server",
|
||||
@@ -596,16 +515,13 @@ func TestStateHistoryService_SessionStartTracking(t *testing.T) {
|
||||
|
||||
instance := tracking.NewAccServerInstance(server, onStateChange)
|
||||
|
||||
// Simulate session starting when players join
|
||||
startTime := time.Now()
|
||||
instance.HandleLogLine("[2024-01-15 14:30:30.456] 1 client(s) online") // First player joins
|
||||
instance.HandleLogLine("[2024-01-15 14:30:30.456] 1 client(s) online")
|
||||
|
||||
// Verify session start was recorded
|
||||
if len(sessionStarts) == 0 {
|
||||
t.Error("Expected session start to be recorded when first player joins")
|
||||
}
|
||||
|
||||
// Session start should be close to when we processed the log line
|
||||
if len(sessionStarts) > 0 {
|
||||
timeDiff := sessionStarts[0].Sub(startTime)
|
||||
if timeDiff > time.Second || timeDiff < -time.Second {
|
||||
|
||||
Reference in New Issue
Block a user