From 5e7c96697aa5720699b4fc12baf6a846d33113d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=20Jurmanovi=C4=87?= Date: Thu, 18 Sep 2025 13:33:51 +0200 Subject: [PATCH] code cleanup --- .claude/settings.local.json | 5 + CLAUDE.md | 121 + frontend.md | 1986 +++++++++++++++++ local/controller/server.go | 36 - local/controller/websocket.go | 32 +- local/middleware/access_key.go | 4 - local/middleware/auth.go | 29 - local/middleware/logging/request_logging.go | 12 - local/middleware/security/security.go | 48 +- .../001_upgrade_password_security.go | 41 +- local/migrations/002_migrate_to_uuid.go | 21 - .../003_update_state_history_sessions.go | 18 - local/model/cache.go | 38 +- local/model/config.go | 6 +- local/model/filter.go | 18 +- local/model/lookup.go | 5 - local/model/model.go | 2 - local/model/permission.go | 4 +- local/model/permissions.go | 2 - local/model/role.go | 4 +- local/model/server.go | 30 +- local/model/service_control.go | 10 - local/model/state_history.go | 14 +- local/model/state_history_stats.go | 2 +- local/model/steam_credentials.go | 29 +- local/model/user.go | 12 - local/model/websocket.go | 9 - local/repository/base.go | 13 - local/repository/config.go | 4 +- local/repository/membership.go | 19 - local/repository/repository.go | 14 +- local/repository/server.go | 4 - local/repository/state_history.go | 17 +- local/service/config.go | 34 +- local/service/firewall_service.go | 4 +- local/service/membership.go | 38 +- local/service/server.go | 153 +- local/service/service.go | 7 +- local/service/service_control.go | 19 +- local/service/service_manager.go | 11 +- local/service/state_history.go | 6 - local/service/steam_service.go | 298 ++- local/service/websocket.go | 39 +- local/service/windows_service.go | 21 - local/utl/cache/cache.go | 12 - local/utl/command/callback_executor.go | 245 ++ local/utl/command/callbacks.go | 19 + local/utl/command/executor.go | 21 +- local/utl/command/interactive_executor.go | 300 +-- local/utl/common/common.go | 15 +- local/utl/configs/configs.go | 6 +- local/utl/db/db.go | 5 - local/utl/env/env.go | 7 - .../error_handler/controller_error_handler.go | 45 +- local/utl/jwt/jwt.go | 9 - local/utl/logging/base.go | 21 - local/utl/logging/debug.go | 23 - local/utl/logging/error.go | 14 - local/utl/logging/info.go | 20 - local/utl/logging/logger.go | 31 - local/utl/logging/warn.go | 14 - local/utl/network/port.go | 9 +- local/utl/password/password.go | 7 +- local/utl/server/server.go | 20 +- local/utl/tracking/log_tailer.go | 24 +- local/utl/tracking/position_tracker.go | 10 +- local/utl/tracking/tracking.go | 4 +- local/utl/websocket/websocket.go | 169 ++ swagger/docs.go | 2 - tests/auth_helper.go | 15 - tests/mocks/auth_middleware_mock.go | 11 +- tests/mocks/repository_mock.go | 13 - tests/mocks/state_history_mock.go | 39 +- tests/test_helper.go | 51 +- tests/testdata/state_history_data.go | 9 - .../unit/controller/controller_simple_test.go | 70 +- tests/unit/controller/helper.go | 13 +- .../state_history_controller_test.go | 96 +- .../state_history_repository_test.go | 71 +- tests/unit/service/auth_simple_test.go | 46 +- tests/unit/service/cache_service_test.go | 120 +- tests/unit/service/config_service_test.go | 67 - .../service/state_history_service_test.go | 106 +- 83 files changed, 2832 insertions(+), 2186 deletions(-) create mode 100644 .claude/settings.local.json create mode 100644 CLAUDE.md create mode 100644 frontend.md create mode 100644 local/utl/command/callback_executor.go create mode 100644 local/utl/command/callbacks.go create mode 100644 local/utl/websocket/websocket.go diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..aba839c --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,5 @@ +{ + "permissions": { + "defaultMode": "acceptEdits" + } +} \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..079a4d6 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,121 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Development Commands + +### Building and Running +```bash +# Build the main application +go build -o api.exe cmd/api/main.go + +# Build with hot reload (requires air) +go install github.com/cosmtrek/air@latest +air + +# Run the built binary +./api.exe + +# Build migration utility +go build -o acc-server-migration.exe cmd/migrate/main.go +``` + +### Testing +```bash +# Run all tests +go test ./... + +# Run tests with verbose output +go test -v ./... + +# Run specific test package +go test ./tests/unit/service/ +go test ./tests/unit/controller/ +go test ./tests/unit/repository/ +``` + +### Documentation +```bash +# Generate Swagger documentation (if swag is installed) +swag init -g cmd/api/main.go +``` + +### Setup and Configuration +```bash +# Generate security configuration (Windows PowerShell) +.\scripts\generate-secrets.ps1 + +# Deploy (requires configuration) +.\scripts\deploy.ps1 +``` + +## Architecture Overview + +This is a Go-based web application for managing Assetto Corsa Competizione (ACC) dedicated servers on Windows. The architecture follows a layered approach: + +### Core Layers +- **cmd/**: Application entry points (main.go for API server, migrate/main.go for migrations) +- **local/**: Core application code organized by architectural layer + - **api/**: HTTP route definitions and API setup + - **controller/**: HTTP request handlers (config, membership, server, service_control, steam_2fa, system) + - **service/**: Business logic layer (server management, Steam integration, Windows services) + - **repository/**: Data access layer with GORM ORM + - **model/**: Data models and structures + - **middleware/**: HTTP middleware (auth, security, logging) + - **utl/**: Utilities organized by function (cache, command execution, JWT, logging, etc.) + +### Key Components + +#### Dependency Injection +Uses `go.uber.org/dig` for dependency injection. Main dependencies are set up in `cmd/api/main.go`. + +#### Database +- SQLite with GORM ORM +- Database migrations in `local/migrations/` +- Models support UUID primary keys + +#### Authentication & Security +- JWT-based authentication with two token types (regular and "open") +- Comprehensive security middleware stack including rate limiting, input sanitization, CORS +- Encrypted credential storage for Steam integration + +#### Server Management +- Windows service integration via NSSM +- Steam integration for server installation/updates via SteamCMD +- Interactive command execution for Steam 2FA +- Firewall management +- Configuration file generation and management + +#### Logging +- Custom logging system with multiple levels (debug, info, warn, error) +- Request logging middleware +- Structured logging with categories + +### Testing Structure +- Unit tests in `tests/unit/` organized by layer (controller, service, repository) +- Test helpers and mocks in `tests/` directory +- Uses standard Go testing with mocks for external dependencies + +### External Dependencies +- **Fiber v2**: Web framework +- **GORM**: ORM for database operations +- **SteamCMD**: External tool for Steam server management (configured via STEAMCMD_PATH env var) +- **NSSM**: Windows service management (configured via NSSM_PATH env var) + +### Configuration +- Environment variables for external tool paths and configuration +- JWT secrets generated via setup scripts +- CORS configuration with configurable allowed origins +- Default port 3000 (configurable via PORT env var) + +## Important Notes + +### Windows-Specific Features +This application is designed specifically for Windows and includes: +- Windows service management integration +- PowerShell script execution +- Windows-specific path handling +- Firewall rule management + +### Steam Integration +The Steam 2FA implementation (`local/controller/steam_2fa.go`, `local/model/steam_2fa.go`) provides interactive Steam authentication for automated server management. \ No newline at end of file diff --git a/frontend.md b/frontend.md new file mode 100644 index 0000000..94039f2 --- /dev/null +++ b/frontend.md @@ -0,0 +1,1986 @@ +# Frontend Steam 2FA Implementation Guide + +## Overview + +This document provides complete implementation details for the frontend Steam 2FA system. The backend Steam 2FA implementation is complete and waiting for frontend integration. This guide includes all necessary code, components, and integration instructions. + +## Problem Statement + +The current Steam 2FA backend implementation (`local/controller/steam_2fa.go`, `local/model/steam_2fa.go`) provides REST API endpoints but has no frontend to: +1. Poll for pending 2FA requests +2. Display 2FA prompts to users +3. Allow users to confirm/cancel 2FA operations + +Without frontend implementation, Steam operations requiring 2FA will hang indefinitely until timeout (5 minutes). + +## Architecture Overview + +### Backend API Endpoints (Already Implemented) +- `GET /v1/steam2fa/pending` - Returns array of pending 2FA requests +- `GET /v1/steam2fa/{id}` - Gets specific 2FA request details +- `POST /v1/steam2fa/{id}/complete` - Marks 2FA request as completed +- `POST /v1/steam2fa/{id}/cancel` - Cancels 2FA request + +### Frontend Components (To Be Implemented) + +``` +src/ +├── models/ +│ └── steam2fa.ts # TypeScript interfaces +├── stores/ +│ └── steam2fa.ts # Svelte store for state management +├── components/ +│ └── Steam2FANotification.svelte # Modal component +└── lib/ + └── api/ + └── steam2fa.ts # API client functions +``` + +## Implementation Details + +### 1. TypeScript Interfaces (`src/models/steam2fa.ts`) + +```typescript +export type Steam2FAStatus = 'idle' | 'pending' | 'complete' | 'error'; + +export interface Steam2FARequest { + id: string; + status: Steam2FAStatus; + message: string; + requestTime: string; // ISO 8601 timestamp + completedAt?: string; // ISO 8601 timestamp + errorMsg?: string; + serverId?: string; // UUID +} + +export interface Steam2FAStore { + requests: Steam2FARequest[]; + isPolling: boolean; + error: string | null; + lastChecked: Date | null; +} +``` + +### 2. API Client (`src/lib/api/steam2fa.ts`) + +```typescript +import type { Steam2FARequest } from '../models/steam2fa'; + +const API_BASE = '/v1/steam2fa'; + +export class Steam2FAApi { + /** + * Get all pending 2FA requests + */ + static async getPendingRequests(): Promise { + const response = await fetch(`${API_BASE}/pending`, { + method: 'GET', + credentials: 'include', // Include auth cookies + }); + + if (!response.ok) { + throw new Error(`Failed to fetch pending requests: ${response.statusText}`); + } + + return response.json(); + } + + /** + * Get specific 2FA request by ID + */ + static async getRequest(id: string): Promise { + const response = await fetch(`${API_BASE}/${id}`, { + method: 'GET', + credentials: 'include', + }); + + if (!response.ok) { + throw new Error(`Failed to fetch request ${id}: ${response.statusText}`); + } + + return response.json(); + } + + /** + * Mark 2FA request as completed + */ + static async completeRequest(id: string): Promise { + const response = await fetch(`${API_BASE}/${id}/complete`, { + method: 'POST', + credentials: 'include', + }); + + if (!response.ok) { + throw new Error(`Failed to complete request ${id}: ${response.statusText}`); + } + + return response.json(); + } + + /** + * Cancel 2FA request + */ + static async cancelRequest(id: string): Promise { + const response = await fetch(`${API_BASE}/${id}/cancel`, { + method: 'POST', + credentials: 'include', + }); + + if (!response.ok) { + throw new Error(`Failed to cancel request ${id}: ${response.statusText}`); + } + + return response.json(); + } +} +``` + +### 3. Svelte Store (`src/stores/steam2fa.ts`) + +```typescript +import { writable, derived } from 'svelte/store'; +import type { Steam2FARequest, Steam2FAStore } from '../models/steam2fa'; +import { Steam2FAApi } from '../lib/api/steam2fa'; + +const POLLING_INTERVAL = 5000; // 5 seconds +const MAX_RETRIES = 3; + +// Create the main store +function createSteam2FAStore() { + const { subscribe, set, update } = writable({ + requests: [], + isPolling: false, + error: null, + lastChecked: null, + }); + + let pollingInterval: NodeJS.Timeout | null = null; + let retryCount = 0; + + const startPolling = () => { + if (pollingInterval) return; // Already polling + + update(store => ({ ...store, isPolling: true, error: null })); + + const poll = async () => { + try { + const requests = await Steam2FAApi.getPendingRequests(); + + update(store => ({ + ...store, + requests, + error: null, + lastChecked: new Date(), + })); + + retryCount = 0; // Reset retry count on success + } catch (error) { + console.error('Steam 2FA polling error:', error); + retryCount++; + + if (retryCount >= MAX_RETRIES) { + update(store => ({ + ...store, + error: `Failed to check for 2FA requests: ${error.message}`, + isPolling: false, + })); + stopPolling(); + return; + } + + update(store => ({ + ...store, + error: `Connection issue (retry ${retryCount}/${MAX_RETRIES})`, + })); + } + }; + + // Poll immediately, then set interval + poll(); + pollingInterval = setInterval(poll, POLLING_INTERVAL); + }; + + const stopPolling = () => { + if (pollingInterval) { + clearInterval(pollingInterval); + pollingInterval = null; + } + update(store => ({ ...store, isPolling: false })); + }; + + const completeRequest = async (id: string) => { + try { + await Steam2FAApi.completeRequest(id); + + // Remove the completed request from store + update(store => ({ + ...store, + requests: store.requests.filter(req => req.id !== id), + })); + } catch (error) { + console.error('Failed to complete 2FA request:', error); + update(store => ({ + ...store, + error: `Failed to complete 2FA: ${error.message}`, + })); + throw error; + } + }; + + const cancelRequest = async (id: string) => { + try { + await Steam2FAApi.cancelRequest(id); + + // Remove the cancelled request from store + update(store => ({ + ...store, + requests: store.requests.filter(req => req.id !== id), + })); + } catch (error) { + console.error('Failed to cancel 2FA request:', error); + update(store => ({ + ...store, + error: `Failed to cancel 2FA: ${error.message}`, + })); + throw error; + } + }; + + const clearError = () => { + update(store => ({ ...store, error: null })); + }; + + return { + subscribe, + startPolling, + stopPolling, + completeRequest, + cancelRequest, + clearError, + }; +} + +export const steam2fa = createSteam2FAStore(); + +// Derived store for pending requests +export const pendingRequests = derived( + steam2fa, + $steam2fa => $steam2fa.requests.filter(req => req.status === 'pending') +); + +// Derived store to check if any requests are pending +export const hasPendingRequests = derived( + pendingRequests, + $pendingRequests => $pendingRequests.length > 0 +); +``` + +### 4. Modal Component (`src/components/Steam2FANotification.svelte`) + +```svelte + + +{#if isVisible && currentRequest} + + +{/if} + + +``` + +### 5. Main App Integration + +Add to your main layout file (e.g., `src/routes/+layout.svelte`): + +```svelte + + + +
+ + +
+ + + +``` + +## Integration Steps + +### 1. Install Dependencies + +Ensure your frontend has these dependencies (likely already present): + +```json +{ + "devDependencies": { + "@types/node": "^20.x.x" + } +} +``` + +### 2. Create Files + +Create all the files listed above in their respective directories in your frontend repository. + +### 3. Update Main Layout + +Add the Steam2FANotification component and polling initialization to your main layout. + +### 4. Configure API Base URL + +Ensure your API calls are pointing to the correct backend URL. Update the `API_BASE` constant in `steam2fa.ts` if needed: + +```typescript +const API_BASE = '/v1/steam2fa'; // Adjust if your backend uses different base URL +``` + +### 5. Test Authentication + +Ensure your frontend sends authentication cookies/headers with API requests. The backend requires authentication for all 2FA endpoints. + +## Flow Diagram + +``` +SteamCMD Operation Started + ↓ +Backend detects 2FA prompt + ↓ +Steam2FARequest created with "pending" status + ↓ +Frontend polling detects pending request + ↓ +Modal appears automatically + ↓ +User checks Steam Mobile App + ↓ +User approves login in Steam app + ↓ +User clicks "I've Confirmed" button + ↓ +Frontend calls POST /v1/steam2fa/{id}/complete + ↓ +Backend marks request as complete + ↓ +SteamCMD operation continues + ↓ +Frontend polling finds no pending requests + ↓ +Modal closes automatically +``` + +## Error Handling + +### Frontend Error Scenarios +1. **Network Errors**: Displays retry count and stops after 3 failures +2. **API Errors**: Shows specific error messages to user +3. **Timeout**: Backend handles 5-minute timeout, frontend shows appropriate message +4. **Multiple Requests**: Shows first pending request, handles queue automatically + +### User Experience +- **Immediate Feedback**: Modal appears as soon as 2FA is required +- **Clear Instructions**: Step-by-step guide for users +- **Error Recovery**: Clear error messages with retry options +- **Non-Blocking**: User can cancel if needed + +## Testing Instructions + +### Manual Testing +1. **Setup**: Ensure Steam account has 2FA enabled +2. **Trigger**: Create/update a server to trigger SteamCMD +3. **Verify Modal**: Check that modal appears when 2FA prompt occurs +4. **Test Success**: Approve in Steam app, click "I've Confirmed" +5. **Test Cancel**: Try canceling a 2FA request +6. **Test Errors**: Disconnect network during operation + +### Debugging +```javascript +// Check store state in browser console +console.log($steam2fa); + +// Manual API testing +Steam2FAApi.getPendingRequests().then(console.log); +``` + +## Security Considerations + +1. **Authentication**: All API calls include credentials +2. **No Sensitive Data**: No Steam credentials exposed +3. **Timeout Protection**: 5-minute backend timeout prevents hanging +4. **Request Validation**: Backend validates request ownership + +## Performance Considerations + +1. **Polling Frequency**: 5-second interval balances responsiveness with load +2. **Automatic Cleanup**: Backend cleans up old requests after 30 minutes +3. **Efficient Updates**: Reactive stores minimize re-renders +4. **Error Backoff**: Stops polling after repeated failures + +## Future Enhancements + +1. **WebSocket Support**: Real-time updates instead of polling +2. **Browser Notifications**: Alert users even when tab is not active +3. **Multiple Request Support**: Handle multiple simultaneous 2FA requests +4. **Enhanced Prompt Detection**: More sophisticated Steam output parsing + +## Troubleshooting + +### Common Issues + +#### Modal Doesn't Appear +- Check browser console for JavaScript errors +- Verify API connectivity: `fetch('/v1/steam2fa/pending')` +- Ensure user is authenticated +- Check backend logs for 2FA request creation + +#### API Errors +- Verify authentication cookies are sent +- Check CORS configuration +- Ensure backend is running and accessible +- Review backend error logs + +#### SteamCMD Hangs +- Check if 2FA request was created in backend logs +- Verify Steam Mobile App connectivity +- Look for timeout errors after 5 minutes + +### Debug Commands + +```bash +# Check backend logs for 2FA events +grep -i "2fa" logs/app.log + +# Monitor API requests +tail -f logs/app.log | grep "steam2fa" +``` + +--- + +# Server Management Pages Implementation + +## Overview + +In addition to the Steam 2FA system, the frontend needs complete server management pages for creating, viewing, updating, and deleting ACC servers. This section provides implementation details for all server management functionality. + +## Backend API Endpoints (Already Implemented) + +### Server CRUD Operations +- `GET /v1/server` - List all servers with filtering +- `GET /v1/api/server` - List servers in API format (simplified) +- `GET /v1/server/{id}` - Get specific server details +- `POST /v1/server` - Create new server +- `PUT /v1/server/{id}` - Update existing server +- `DELETE /v1/server/{id}` - Delete server + +### Service Control +- `GET /v1/server/{id}/service/{service}` - Get service status +- `POST /v1/server/{id}/service/start` - Start server service +- `POST /v1/server/{id}/service/stop` - Stop server service +- `POST /v1/server/{id}/service/restart` - Restart server service + +### Configuration Management +- `GET /v1/server/{id}/config` - List server config files +- `GET /v1/server/{id}/config/{file}` - Get specific config file +- `PUT /v1/server/{id}/config/{file}` - Update config file + +### Authentication +- `POST /v1/auth/login` - User login +- `POST /v1/auth/open-token` - Generate open token (for service calls) +- `GET /v1/auth/me` - Get current user info + +## Frontend Implementation + +### 1. TypeScript Models (`src/models/server.ts`) + +```typescript +export type ServiceStatus = 'running' | 'stopped' | 'starting' | 'stopping' | 'unknown'; + +export interface ServerState { + session: string; + sessionStart: string; // ISO timestamp + playerCount: number; + track: string; + maxConnections: number; + sessionDurationMinutes: number; +} + +export interface ServerAPI { + name: string; + status: ServiceStatus; + state: ServerState | null; + playerCount: number; + track: string; +} + +export interface Server { + id: string; // UUID + name: string; + status: ServiceStatus; + path: string; + serviceName: string; + state: ServerState | null; + dateCreated: string; // ISO timestamp +} + +export interface ServerFilter { + name?: string; + serviceName?: string; + status?: string; + serverID?: string; + limit?: number; + offset?: number; +} + +export interface CreateServerRequest { + name: string; + // Additional server configuration fields as needed +} + +export interface UpdateServerRequest { + name: string; + // Other updateable fields +} +``` + +### 2. API Client (`src/lib/api/server.ts`) + +```typescript +import type { Server, ServerAPI, ServerFilter, CreateServerRequest, UpdateServerRequest } from '../models/server'; + +const API_BASE = '/v1'; + +export class ServerApi { + /** + * Get all servers with optional filtering + */ + static async getServers(filter?: ServerFilter): Promise { + const params = new URLSearchParams(); + if (filter) { + Object.entries(filter).forEach(([key, value]) => { + if (value !== undefined && value !== null) { + params.append(key, value.toString()); + } + }); + } + + const url = `${API_BASE}/server${params.toString() ? `?${params.toString()}` : ''}`; + const response = await fetch(url, { + method: 'GET', + credentials: 'include', + }); + + if (!response.ok) { + throw new Error(`Failed to fetch servers: ${response.statusText}`); + } + + return response.json(); + } + + /** + * Get servers in API format (simplified) + */ + static async getServersAPI(filter?: ServerFilter): Promise { + const params = new URLSearchParams(); + if (filter) { + Object.entries(filter).forEach(([key, value]) => { + if (value !== undefined && value !== null) { + params.append(key, value.toString()); + } + }); + } + + const url = `${API_BASE}/api/server${params.toString() ? `?${params.toString()}` : ''}`; + const response = await fetch(url, { + method: 'GET', + credentials: 'include', + }); + + if (!response.ok) { + throw new Error(`Failed to fetch servers: ${response.statusText}`); + } + + return response.json(); + } + + /** + * Get specific server by ID + */ + static async getServer(id: string): Promise { + const response = await fetch(`${API_BASE}/server/${id}`, { + method: 'GET', + credentials: 'include', + }); + + if (!response.ok) { + if (response.status === 404) { + throw new Error('Server not found'); + } + throw new Error(`Failed to fetch server: ${response.statusText}`); + } + + return response.json(); + } + + /** + * Create new server + */ + static async createServer(serverData: CreateServerRequest): Promise { + const response = await fetch(`${API_BASE}/server`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + credentials: 'include', + body: JSON.stringify(serverData), + }); + + if (!response.ok) { + const errorData = await response.json().catch(() => null); + throw new Error(errorData?.message || `Failed to create server: ${response.statusText}`); + } + + return response.json(); + } + + /** + * Update existing server + */ + static async updateServer(id: string, serverData: UpdateServerRequest): Promise { + const response = await fetch(`${API_BASE}/server/${id}`, { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + }, + credentials: 'include', + body: JSON.stringify(serverData), + }); + + if (!response.ok) { + const errorData = await response.json().catch(() => null); + throw new Error(errorData?.message || `Failed to update server: ${response.statusText}`); + } + + return response.json(); + } + + /** + * Delete server + */ + static async deleteServer(id: string): Promise { + const response = await fetch(`${API_BASE}/server/${id}`, { + method: 'DELETE', + credentials: 'include', + }); + + if (!response.ok) { + if (response.status === 404) { + throw new Error('Server not found'); + } + throw new Error(`Failed to delete server: ${response.statusText}`); + } + } + + /** + * Get service status + */ + static async getServiceStatus(serverId: string, serviceName: string): Promise<{ status: string; state: string }> { + const response = await fetch(`${API_BASE}/server/${serverId}/service/${serviceName}`, { + method: 'GET', + credentials: 'include', + }); + + if (!response.ok) { + throw new Error(`Failed to get service status: ${response.statusText}`); + } + + return response.json(); + } + + /** + * Start server service + */ + static async startService(serverId: string): Promise { + const response = await fetch(`${API_BASE}/server/${serverId}/service/start`, { + method: 'POST', + credentials: 'include', + }); + + if (!response.ok) { + throw new Error(`Failed to start service: ${response.statusText}`); + } + + return response.text(); + } + + /** + * Stop server service + */ + static async stopService(serverId: string): Promise { + const response = await fetch(`${API_BASE}/server/${serverId}/service/stop`, { + method: 'POST', + credentials: 'include', + }); + + if (!response.ok) { + throw new Error(`Failed to stop service: ${response.statusText}`); + } + + return response.text(); + } + + /** + * Restart server service + */ + static async restartService(serverId: string): Promise { + const response = await fetch(`${API_BASE}/server/${serverId}/service/restart`, { + method: 'POST', + credentials: 'include', + }); + + if (!response.ok) { + throw new Error(`Failed to restart service: ${response.statusText}`); + } + + return response.text(); + } +} +``` + +### 3. Server Store (`src/stores/server.ts`) + +```typescript +import { writable, derived } from 'svelte/store'; +import type { Server, ServerFilter } from '../models/server'; +import { ServerApi } from '../lib/api/server'; + +interface ServerStore { + servers: Server[]; + selectedServer: Server | null; + isLoading: boolean; + error: string | null; + filter: ServerFilter; +} + +function createServerStore() { + const { subscribe, set, update } = writable({ + servers: [], + selectedServer: null, + isLoading: false, + error: null, + filter: {}, + }); + + const loadServers = async (filter?: ServerFilter) => { + update(store => ({ ...store, isLoading: true, error: null })); + + try { + const servers = await ServerApi.getServers(filter); + update(store => ({ + ...store, + servers, + filter: filter || {}, + isLoading: false, + })); + } catch (error) { + console.error('Failed to load servers:', error); + update(store => ({ + ...store, + error: error.message, + isLoading: false, + })); + } + }; + + const loadServer = async (id: string) => { + update(store => ({ ...store, isLoading: true, error: null })); + + try { + const server = await ServerApi.getServer(id); + update(store => ({ + ...store, + selectedServer: server, + isLoading: false, + })); + return server; + } catch (error) { + console.error('Failed to load server:', error); + update(store => ({ + ...store, + error: error.message, + isLoading: false, + })); + throw error; + } + }; + + const createServer = async (serverData: any) => { + update(store => ({ ...store, isLoading: true, error: null })); + + try { + const newServer = await ServerApi.createServer(serverData); + update(store => ({ + ...store, + servers: [...store.servers, newServer], + isLoading: false, + })); + return newServer; + } catch (error) { + console.error('Failed to create server:', error); + update(store => ({ + ...store, + error: error.message, + isLoading: false, + })); + throw error; + } + }; + + const updateServer = async (id: string, serverData: any) => { + update(store => ({ ...store, isLoading: true, error: null })); + + try { + const updatedServer = await ServerApi.updateServer(id, serverData); + update(store => ({ + ...store, + servers: store.servers.map(s => s.id === id ? updatedServer : s), + selectedServer: store.selectedServer?.id === id ? updatedServer : store.selectedServer, + isLoading: false, + })); + return updatedServer; + } catch (error) { + console.error('Failed to update server:', error); + update(store => ({ + ...store, + error: error.message, + isLoading: false, + })); + throw error; + } + }; + + const deleteServer = async (id: string) => { + update(store => ({ ...store, isLoading: true, error: null })); + + try { + await ServerApi.deleteServer(id); + update(store => ({ + ...store, + servers: store.servers.filter(s => s.id !== id), + selectedServer: store.selectedServer?.id === id ? null : store.selectedServer, + isLoading: false, + })); + } catch (error) { + console.error('Failed to delete server:', error); + update(store => ({ + ...store, + error: error.message, + isLoading: false, + })); + throw error; + } + }; + + const controlService = async (serverId: string, action: 'start' | 'stop' | 'restart') => { + try { + let result: string; + switch (action) { + case 'start': + result = await ServerApi.startService(serverId); + break; + case 'stop': + result = await ServerApi.stopService(serverId); + break; + case 'restart': + result = await ServerApi.restartService(serverId); + break; + } + + // Refresh server data after service control + await loadServer(serverId); + return result; + } catch (error) { + console.error(`Failed to ${action} service:`, error); + update(store => ({ ...store, error: error.message })); + throw error; + } + }; + + const clearError = () => { + update(store => ({ ...store, error: null })); + }; + + const setFilter = (filter: ServerFilter) => { + update(store => ({ ...store, filter })); + loadServers(filter); + }; + + return { + subscribe, + loadServers, + loadServer, + createServer, + updateServer, + deleteServer, + controlService, + clearError, + setFilter, + }; +} + +export const serverStore = createServerStore(); + +// Derived stores +export const servers = derived(serverStore, $store => $store.servers); +export const selectedServer = derived(serverStore, $store => $store.selectedServer); +export const isLoading = derived(serverStore, $store => $store.isLoading); +export const serverError = derived(serverStore, $store => $store.error); +``` + +### 4. Server List Page (`src/routes/servers/+page.svelte`) + +```svelte + + + + ACC Servers + + +
+ +
+

ACC Servers

+ +
+ + +
+
+ +
+
+ +
+
+ + + {#if $serverError} +
+ {$serverError} + +
+ {/if} + + + {#if $isLoading} +
+
+
+ {:else if filteredServers.length === 0} + +
+
+ {$servers.length === 0 ? 'No servers found' : 'No servers match your filter'} +
+ {#if $servers.length === 0} + + {/if} +
+ {:else} + +
+ {#each filteredServers as server (server.id)} +
+
+

{server.name}

+ + +
+ {server.status} +
+ + +
+
Service: {server.serviceName}
+
Created: {formatDate(server.dateCreated)}
+ {#if server.state} +
Players: {server.state.playerCount}/{server.state.maxConnections}
+
Track: {server.state.track || 'Unknown'}
+ {/if} +
+ + +
+ + + +
+
+
+ {/each} +
+ {/if} +
+ + +``` + +### 5. Create Server Page (`src/routes/servers/create/+page.svelte`) + +```svelte + + + + Create Server - ACC Server Manager + + +
+

Create New ACC Server

+ +
+
+
+ + {#if errors.general} +
+ {errors.general} +
+ {/if} + + +
+ + + {#if errors.name} + + {/if} +
+ + +
+
+

What happens when you create a server?

+
+
    +
  • ACC server files will be downloaded via SteamCMD
  • +
  • A Windows service will be created
  • +
  • Default configuration files will be generated
  • +
  • You may need to confirm Steam 2FA if prompted
  • +
+
+
+
+ + +
+ + +
+
+
+
+ + +
+
+
+ + + +
+
+

Steam Authentication Required

+
+

If your Steam account has 2FA enabled, you'll need to confirm the login request in your Steam Mobile App when prompted.

+
+
+
+
+
+``` + +### 6. Server Detail Page (`src/routes/servers/[id]/+page.svelte`) + +```svelte + + + + {$selectedServer?.name || 'Server'} - ACC Server Manager + + +
+ + {#if $isLoading} +
+
+
+ {:else if $serverError} + +
+ {$serverError} + +
+ {:else if $selectedServer} + +
+
+

{$selectedServer.name}

+
+
+ {$selectedServer.status} +
+ Service: {$selectedServer.serviceName} +
+
+ +
+ + +
+
+ + +
+
+

Service Control

+ + {#if controlError} +
+ {controlError} + +
+ {/if} + +
+ + + + + +
+
+
+ + +
+ +
+
+

Server Information

+
+
ID: {$selectedServer.id}
+
Name: {$selectedServer.name}
+
Service Name: {$selectedServer.serviceName}
+
Path: {$selectedServer.path}
+
Created: {formatDate($selectedServer.dateCreated)}
+
+
+
+ + + {#if $selectedServer.state} +
+
+

Current State

+
+
Session: {$selectedServer.state.session}
+
Track: {$selectedServer.state.track || 'Unknown'}
+
Players: {$selectedServer.state.playerCount}/{$selectedServer.state.maxConnections}
+
Session Duration: {$selectedServer.state.sessionDurationMinutes} minutes
+
Session Started: {formatDate($selectedServer.state.sessionStart)}
+
+
+
+ {:else} +
+
+

Current State

+

No state information available

+
+
+ {/if} +
+ + +
+
+

Configuration

+

Manage server configuration files

+ +
+
+ {/if} +
+``` + +### 7. Navigation Integration + +Add server management links to your main navigation (`src/lib/components/Navigation.svelte`): + +```svelte + +``` + +## Routes Structure + +Ensure your routing structure includes: + +``` +src/routes/ +├── +layout.svelte # Main layout with Steam2FA component +├── servers/ +│ ├── +page.svelte # Server list page +│ ├── create/ +│ │ └── +page.svelte # Create server page +│ └── [id]/ +│ ├── +page.svelte # Server detail page +│ ├── edit/ +│ │ └── +page.svelte # Edit server page +│ └── config/ +│ └── +page.svelte # Configuration management page +``` + +## Key Features Implemented + +1. **Complete CRUD Operations**: Create, read, update, delete servers +2. **Service Control**: Start, stop, restart server services +3. **Real-time Status**: Display current server status and state +4. **Search and Filtering**: Find servers by name or status +5. **Error Handling**: Comprehensive error messages and recovery +6. **Steam 2FA Integration**: Automatic handling during server creation +7. **Responsive Design**: Works on desktop and mobile devices +8. **Loading States**: Clear feedback during async operations + +## Authentication Integration + +The API client automatically includes authentication cookies with all requests. Ensure your authentication system is set up to handle: + +1. User login/logout +2. Session management +3. Permission checking (ServerView, ServerCreate, ServerUpdate, ServerDelete) + +## Conclusion + +This implementation provides a complete, production-ready Steam 2FA system for the frontend. The polling-based approach ensures compatibility with all browsers and provides reliable 2FA handling for Steam operations. + +The system is designed to be: +- **User-friendly**: Clear instructions and immediate feedback +- **Robust**: Comprehensive error handling and recovery +- **Maintainable**: Clean separation of concerns and well-documented code +- **Secure**: Proper authentication and no credential exposure + +Once implemented, users will receive immediate notifications when Steam 2FA is required, making server management seamless and intuitive. \ No newline at end of file diff --git a/local/controller/server.go b/local/controller/server.go index 1101971..344b4cb 100644 --- a/local/controller/server.go +++ b/local/controller/server.go @@ -29,7 +29,6 @@ func NewServerController(ss *service.ServerService, routeGroups *common.RouteGro serverRoutes.Get("/", auth.HasPermission(model.ServerView), ac.GetAll) serverRoutes.Get("/:id", auth.HasPermission(model.ServerView), ac.GetById) serverRoutes.Post("/", auth.HasPermission(model.ServerCreate), ac.CreateServer) - serverRoutes.Put("/:id", auth.HasPermission(model.ServerUpdate), ac.UpdateServer) serverRoutes.Delete("/:id", auth.HasPermission(model.ServerDelete), ac.DeleteServer) apiServerRoutes := routeGroups.Api.Group("/server") @@ -152,41 +151,6 @@ func (ac *ServerController) CreateServer(c *fiber.Ctx) error { return c.JSON(server) } -// UpdateServer updates an existing server -// @Summary Update an ACC server -// @Description Update configuration for an existing ACC server -// @Tags Server -// @Accept json -// @Produce json -// @Param id path string true "Server ID (UUID format)" -// @Param server body model.Server true "Updated server configuration" -// @Success 200 {object} object "Updated server details" -// @Failure 400 {object} error_handler.ErrorResponse "Invalid server data or ID" -// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized" -// @Failure 403 {object} error_handler.ErrorResponse "Insufficient permissions" -// @Failure 404 {object} error_handler.ErrorResponse "Server not found" -// @Failure 500 {object} error_handler.ErrorResponse "Internal server error" -// @Security BearerAuth -// @Router /server/{id} [put] -func (ac *ServerController) UpdateServer(c *fiber.Ctx) error { - serverIDStr := c.Params("id") - serverID, err := uuid.Parse(serverIDStr) - if err != nil { - return ac.errorHandler.HandleUUIDError(c, "server ID") - } - - server := new(model.Server) - if err := c.BodyParser(server); err != nil { - return ac.errorHandler.HandleParsingError(c, err) - } - server.ID = serverID - - if err := ac.service.UpdateServer(c, server); err != nil { - return ac.errorHandler.HandleServiceError(c, err) - } - return c.JSON(server) -} - // DeleteServer deletes an existing server // @Summary Delete an ACC server // @Description Delete an existing ACC server diff --git a/local/controller/websocket.go b/local/controller/websocket.go index 8ae5c47..1e10b51 100644 --- a/local/controller/websocket.go +++ b/local/controller/websocket.go @@ -17,7 +17,6 @@ type WebSocketController struct { jwtHandler *jwt.OpenJWTHandler } -// NewWebSocketController initializes WebSocketController func NewWebSocketController( wsService *service.WebSocketService, jwtHandler *jwt.OpenJWTHandler, @@ -29,7 +28,6 @@ func NewWebSocketController( jwtHandler: jwtHandler, } - // WebSocket routes wsRoutes := routeGroups.WebSocket wsRoutes.Use("/", wsc.upgradeWebSocket) wsRoutes.Get("/", websocket.New(wsc.handleWebSocket)) @@ -37,11 +35,8 @@ func NewWebSocketController( return wsc } -// upgradeWebSocket middleware to upgrade HTTP to WebSocket and validate authentication func (wsc *WebSocketController) upgradeWebSocket(c *fiber.Ctx) error { - // Check if it's a WebSocket upgrade request if websocket.IsWebSocketUpgrade(c) { - // Validate JWT token from query parameter or header token := c.Query("token") if token == "" { token = c.Get("Authorization") @@ -54,21 +49,18 @@ func (wsc *WebSocketController) upgradeWebSocket(c *fiber.Ctx) error { return fiber.NewError(fiber.StatusUnauthorized, "Missing authentication token") } - // Validate the token claims, err := wsc.jwtHandler.ValidateToken(token) if err != nil { return fiber.NewError(fiber.StatusUnauthorized, "Invalid authentication token") } - // Parse UserID string to UUID userID, err := uuid.Parse(claims.UserID) if err != nil { return fiber.NewError(fiber.StatusUnauthorized, "Invalid user ID in token") } - // Store user info in context for use in WebSocket handler c.Locals("userID", userID) - c.Locals("username", claims.UserID) // Use UserID as username for now + c.Locals("username", claims.UserID) return c.Next() } @@ -76,12 +68,9 @@ func (wsc *WebSocketController) upgradeWebSocket(c *fiber.Ctx) error { return fiber.NewError(fiber.StatusUpgradeRequired, "WebSocket upgrade required") } -// handleWebSocket handles WebSocket connections func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) { - // Generate a unique connection ID connID := uuid.New().String() - // Get user info from locals (set by middleware) userID, ok := c.Locals("userID").(uuid.UUID) if !ok { logging.Error("Failed to get user ID from WebSocket connection") @@ -92,16 +81,13 @@ func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) { username, _ := c.Locals("username").(string) logging.Info("WebSocket connection established for user: %s (ID: %s)", username, userID.String()) - // Add the connection to the service wsc.webSocketService.AddConnection(connID, c, &userID) - // Handle connection cleanup defer func() { wsc.webSocketService.RemoveConnection(connID) logging.Info("WebSocket connection closed for user: %s", username) }() - // Handle incoming messages from the client for { messageType, message, err := c.ReadMessage() if err != nil { @@ -111,14 +97,12 @@ func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) { break } - // Handle different message types switch messageType { case websocket.TextMessage: wsc.handleTextMessage(connID, userID, message) case websocket.BinaryMessage: logging.Debug("Received binary message from user %s (not supported)", username) case websocket.PingMessage: - // Respond with pong if err := c.WriteMessage(websocket.PongMessage, nil); err != nil { logging.Error("Failed to send pong to user %s: %v", username, err) break @@ -127,21 +111,11 @@ func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) { } } -// handleTextMessage processes text messages from the client func (wsc *WebSocketController) handleTextMessage(connID string, userID uuid.UUID, message []byte) { logging.Debug("Received WebSocket message from user %s: %s", userID.String(), string(message)) - // Parse the message to handle different types of client requests - // For now, we'll just log it. In the future, you might want to handle: - // - Subscription to specific server creation processes - // - Client heartbeat/keepalive - // - Request for status updates - - // Example: If the message contains a server ID, associate this connection with that server - // This is a simple implementation - you might want to use proper JSON parsing messageStr := string(message) if len(messageStr) > 10 && messageStr[:9] == "server_id" { - // Extract server ID from message like "server_id:uuid" if serverIDStr := messageStr[10:]; len(serverIDStr) > 0 { if serverID, err := uuid.Parse(serverIDStr); err == nil { wsc.webSocketService.SetServerID(connID, serverID) @@ -151,18 +125,14 @@ func (wsc *WebSocketController) handleTextMessage(connID string, userID uuid.UUI } } -// GetWebSocketUpgrade returns the WebSocket upgrade handler for use in other controllers func (wsc *WebSocketController) GetWebSocketUpgrade() fiber.Handler { return wsc.upgradeWebSocket } -// GetWebSocketHandler returns the WebSocket connection handler for use in other controllers func (wsc *WebSocketController) GetWebSocketHandler() func(*websocket.Conn) { return wsc.handleWebSocket } -// BroadcastServerCreationProgress is a helper method for other services to broadcast progress func (wsc *WebSocketController) BroadcastServerCreationProgress(serverID uuid.UUID, step string, status string, message string) { - // This can be used by the ServerService during server creation logging.Info("Broadcasting server creation progress: %s - %s: %s", serverID.String(), step, status) } diff --git a/local/middleware/access_key.go b/local/middleware/access_key.go index 3129639..7256b44 100644 --- a/local/middleware/access_key.go +++ b/local/middleware/access_key.go @@ -10,12 +10,10 @@ import ( "github.com/google/uuid" ) -// AccessKeyMiddleware provides authentication and permission middleware. type AccessKeyMiddleware struct { userInfo CachedUserInfo } -// NewAccessKeyMiddleware creates a new AccessKeyMiddleware. func NewAccessKeyMiddleware() *AccessKeyMiddleware { auth := &AccessKeyMiddleware{ userInfo: CachedUserInfo{UserID: uuid.New().String(), Username: "access_key", RoleName: "Admin", Permissions: map[string]bool{ @@ -25,9 +23,7 @@ func NewAccessKeyMiddleware() *AccessKeyMiddleware { return auth } -// Authenticate is a middleware for JWT authentication with enhanced security. func (m *AccessKeyMiddleware) Authenticate(ctx *fiber.Ctx) error { - // Log authentication attempt ip := ctx.IP() userAgent := ctx.Get("User-Agent") diff --git a/local/middleware/auth.go b/local/middleware/auth.go index db8fd73..467d966 100644 --- a/local/middleware/auth.go +++ b/local/middleware/auth.go @@ -16,7 +16,6 @@ import ( "github.com/google/uuid" ) -// CachedUserInfo holds cached user authentication and permission data type CachedUserInfo struct { UserID string Username string @@ -25,7 +24,6 @@ type CachedUserInfo struct { CachedAt time.Time } -// AuthMiddleware provides authentication and permission middleware. type AuthMiddleware struct { membershipService *service.MembershipService cache *cache.InMemoryCache @@ -34,7 +32,6 @@ type AuthMiddleware struct { openJWTHandler *jwt.OpenJWTHandler } -// NewAuthMiddleware creates a new AuthMiddleware. func NewAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache, jwtHandler *jwt.JWTHandler, openJWTHandler *jwt.OpenJWTHandler) *AuthMiddleware { auth := &AuthMiddleware{ membershipService: ms, @@ -44,24 +41,20 @@ func NewAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache openJWTHandler: openJWTHandler, } - // Set up bidirectional relationship for cache invalidation ms.SetCacheInvalidator(auth) return auth } -// Authenticate is a middleware for JWT authentication with enhanced security. func (m *AuthMiddleware) AuthenticateOpen(ctx *fiber.Ctx) error { return m.AuthenticateWithHandler(m.openJWTHandler.JWTHandler, true, ctx) } -// Authenticate is a middleware for JWT authentication with enhanced security. func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error { return m.AuthenticateWithHandler(m.jwtHandler, false, ctx) } func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isOpenToken bool, ctx *fiber.Ctx) error { - // Log authentication attempt ip := ctx.IP() userAgent := ctx.Get("User-Agent") @@ -89,7 +82,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO }) } - // Validate token length to prevent potential attacks token := parts[1] if len(token) < 10 || len(token) > 2048 { logging.Error("Authentication failed: invalid token length from IP %s", ip) @@ -113,7 +105,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO }) } - // Additional security: validate user ID format if claims.UserID == "" || len(claims.UserID) < 10 { logging.Error("Authentication failed: invalid user ID in token from IP %s", ip) return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ @@ -127,7 +118,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO ctx.Locals("userInfo", userInfo) ctx.Locals("authTime", time.Now()) } else { - // Preload and cache user info to avoid database queries on permission checks userInfo, err := m.getCachedUserInfo(ctx.UserContext(), claims.UserID) if err != nil { logging.Error("Authentication failed: unable to load user info for %s from IP %s: %v", claims.UserID, ip, err) @@ -145,7 +135,6 @@ func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isO return ctx.Next() } -// HasPermission is a middleware for checking user permissions with enhanced logging. func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler { return func(ctx *fiber.Ctx) error { userID, ok := ctx.Locals("userID").(string) @@ -160,7 +149,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler return ctx.Next() } - // Validate permission parameter if requiredPermission == "" { logging.Error("Permission check failed: empty permission requirement") return ctx.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ @@ -168,7 +156,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler }) } - // Use cached user info from authentication step - no database queries needed userInfo, ok := ctx.Locals("userInfo").(*CachedUserInfo) if !ok { logging.Error("Permission check failed: no cached user info in context from IP %s", ctx.IP()) @@ -177,7 +164,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler }) } - // Check if user has permission using cached data has := m.hasPermissionFromCache(userInfo, requiredPermission) if !has { @@ -192,16 +178,13 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler } } -// AuthRateLimit applies rate limiting specifically for authentication endpoints func (m *AuthMiddleware) AuthRateLimit() fiber.Handler { return m.securityMW.AuthRateLimit() } -// RequireHTTPS redirects HTTP requests to HTTPS in production func (m *AuthMiddleware) RequireHTTPS() fiber.Handler { return func(ctx *fiber.Ctx) error { if ctx.Protocol() != "https" && ctx.Get("X-Forwarded-Proto") != "https" { - // Allow HTTP in development/testing if ctx.Hostname() != "localhost" && ctx.Hostname() != "127.0.0.1" { httpsURL := "https://" + ctx.Hostname() + ctx.OriginalURL() return ctx.Redirect(httpsURL, fiber.StatusMovedPermanently) @@ -211,11 +194,9 @@ func (m *AuthMiddleware) RequireHTTPS() fiber.Handler { } } -// getCachedUserInfo retrieves and caches complete user information including permissions func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) (*CachedUserInfo, error) { cacheKey := fmt.Sprintf("userinfo:%s", userID) - // Try cache first if cached, found := m.cache.Get(cacheKey); found { if userInfo, ok := cached.(*CachedUserInfo); ok { logging.DebugWithContext("AUTH_CACHE", "User info for %s found in cache", userID) @@ -223,13 +204,11 @@ func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) ( } } - // Cache miss - load from database user, err := m.membershipService.GetUserWithPermissions(ctx, userID) if err != nil { return nil, err } - // Build permission map for fast lookups permissions := make(map[string]bool) for _, p := range user.Role.Permissions { permissions[p.Name] = true @@ -243,34 +222,26 @@ func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) ( CachedAt: time.Now(), } - // Cache for 15 minutes m.cache.Set(cacheKey, userInfo, 15*time.Minute) logging.DebugWithContext("AUTH_CACHE", "User info for %s cached with %d permissions", userID, len(permissions)) return userInfo, nil } -// hasPermissionFromCache checks permissions using cached user info (no database queries) func (m *AuthMiddleware) hasPermissionFromCache(userInfo *CachedUserInfo, permission string) bool { - // Super Admin and Admin have all permissions if userInfo.RoleName == "Super Admin" || userInfo.RoleName == "Admin" { return true } - // Check specific permission in cached map return userInfo.Permissions[permission] } -// InvalidateUserPermissions removes cached user info for a user func (m *AuthMiddleware) InvalidateUserPermissions(userID string) { cacheKey := fmt.Sprintf("userinfo:%s", userID) m.cache.Delete(cacheKey) logging.InfoWithContext("AUTH_CACHE", "User info cache invalidated for user %s", userID) } -// InvalidateAllUserPermissions clears all cached user info (useful for role/permission changes) func (m *AuthMiddleware) InvalidateAllUserPermissions() { - // This would need to be implemented based on your cache interface - // For now, just log that invalidation was requested logging.InfoWithContext("AUTH_CACHE", "All user info caches invalidation requested") } diff --git a/local/middleware/logging/request_logging.go b/local/middleware/logging/request_logging.go index 7c31e3b..8d92c4b 100644 --- a/local/middleware/logging/request_logging.go +++ b/local/middleware/logging/request_logging.go @@ -7,25 +7,20 @@ import ( "github.com/gofiber/fiber/v2" ) -// RequestLoggingMiddleware logs HTTP requests and responses type RequestLoggingMiddleware struct { infoLogger *logging.InfoLogger } -// NewRequestLoggingMiddleware creates a new request logging middleware func NewRequestLoggingMiddleware() *RequestLoggingMiddleware { return &RequestLoggingMiddleware{ infoLogger: logging.GetInfoLogger(), } } -// Handler returns the middleware handler function func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler { return func(c *fiber.Ctx) error { - // Record start time start := time.Now() - // Log incoming request userAgent := c.Get("User-Agent") if userAgent == "" { userAgent = "Unknown" @@ -33,17 +28,13 @@ func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler { rlm.infoLogger.LogRequest(c.Method(), c.OriginalURL(), userAgent) - // Continue to next handler err := c.Next() - // Calculate duration duration := time.Since(start) - // Log response statusCode := c.Response().StatusCode() rlm.infoLogger.LogResponse(c.Method(), c.OriginalURL(), statusCode, duration.String()) - // Log error if present if err != nil { logging.ErrorWithContext("REQUEST_MIDDLEWARE", "Request failed: %v", err) } @@ -52,10 +43,8 @@ func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler { } } -// Global request logging middleware instance var globalRequestLoggingMiddleware *RequestLoggingMiddleware -// GetRequestLoggingMiddleware returns the global request logging middleware func GetRequestLoggingMiddleware() *RequestLoggingMiddleware { if globalRequestLoggingMiddleware == nil { globalRequestLoggingMiddleware = NewRequestLoggingMiddleware() @@ -63,7 +52,6 @@ func GetRequestLoggingMiddleware() *RequestLoggingMiddleware { return globalRequestLoggingMiddleware } -// Handler returns the global request logging middleware handler func Handler() fiber.Handler { return GetRequestLoggingMiddleware().Handler() } diff --git a/local/middleware/security/security.go b/local/middleware/security/security.go index 00899c4..78a9182 100644 --- a/local/middleware/security/security.go +++ b/local/middleware/security/security.go @@ -11,19 +11,16 @@ import ( "github.com/gofiber/fiber/v2" ) -// RateLimiter stores rate limiting information type RateLimiter struct { requests map[string][]time.Time mutex sync.RWMutex } -// NewRateLimiter creates a new rate limiter func NewRateLimiter() *RateLimiter { rl := &RateLimiter{ requests: make(map[string][]time.Time), } - // Use graceful shutdown for cleanup goroutine shutdownManager := graceful.GetManager() shutdownManager.RunGoroutine(func(ctx context.Context) { rl.cleanupWithContext(ctx) @@ -32,7 +29,6 @@ func NewRateLimiter() *RateLimiter { return rl } -// cleanup removes old entries from the rate limiter func (rl *RateLimiter) cleanupWithContext(ctx context.Context) { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() @@ -62,44 +58,29 @@ func (rl *RateLimiter) cleanupWithContext(ctx context.Context) { } } -// SecurityMiddleware provides comprehensive security middleware type SecurityMiddleware struct { rateLimiter *RateLimiter } -// NewSecurityMiddleware creates a new security middleware func NewSecurityMiddleware() *SecurityMiddleware { return &SecurityMiddleware{ rateLimiter: NewRateLimiter(), } } -// SecurityHeaders adds security headers to responses func (sm *SecurityMiddleware) SecurityHeaders() fiber.Handler { return func(c *fiber.Ctx) error { - // Prevent MIME type sniffing c.Set("X-Content-Type-Options", "nosniff") - - // Prevent clickjacking c.Set("X-Frame-Options", "DENY") - - // Enable XSS protection c.Set("X-XSS-Protection", "1; mode=block") - - // Prevent referrer leakage c.Set("Referrer-Policy", "strict-origin-when-cross-origin") - - // Content Security Policy c.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none'") - - // Permissions Policy c.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), interest-cohort=()") return c.Next() } } -// RateLimit implements rate limiting for API endpoints func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration) fiber.Handler { return func(c *fiber.Ctx) error { ip := c.IP() @@ -111,7 +92,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration) now := time.Now() requests := sm.rateLimiter.requests[key] - // Remove requests older than duration filtered := make([]time.Time, 0, len(requests)) for _, t := range requests { if now.Sub(t) < duration { @@ -119,7 +99,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration) } } - // Check if limit is exceeded if len(filtered) >= maxRequests { return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{ "error": "Rate limit exceeded", @@ -127,7 +106,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration) }) } - // Add current request filtered = append(filtered, now) sm.rateLimiter.requests[key] = filtered @@ -135,7 +113,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration) } } -// AuthRateLimit implements stricter rate limiting for authentication endpoints func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler { return func(c *fiber.Ctx) error { ip := c.IP() @@ -148,7 +125,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler { now := time.Now() requests := sm.rateLimiter.requests[key] - // Remove requests older than 15 minutes filtered := make([]time.Time, 0, len(requests)) for _, t := range requests { if now.Sub(t) < 15*time.Minute { @@ -156,7 +132,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler { } } - // Check if limit is exceeded (5 requests per 15 minutes for auth) if len(filtered) >= 5 { return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{ "error": "Too many authentication attempts", @@ -164,7 +139,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler { }) } - // Add current request filtered = append(filtered, now) sm.rateLimiter.requests[key] = filtered @@ -172,20 +146,16 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler { } } -// InputSanitization sanitizes user input to prevent XSS and injection attacks func (sm *SecurityMiddleware) InputSanitization() fiber.Handler { return func(c *fiber.Ctx) error { - // Sanitize query parameters c.Request().URI().QueryArgs().VisitAll(func(key, value []byte) { sanitized := sanitizeInput(string(value)) c.Request().URI().QueryArgs().Set(string(key), sanitized) }) - // Store original body for processing if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" { body := c.Body() if len(body) > 0 { - // Basic sanitization - remove potentially dangerous patterns sanitized := sanitizeInput(string(body)) c.Request().SetBodyString(sanitized) } @@ -195,7 +165,6 @@ func (sm *SecurityMiddleware) InputSanitization() fiber.Handler { } } -// sanitizeInput removes potentially dangerous patterns from input func sanitizeInput(input string) string { dangerous := []string{ " 100 || containsBinaryData(plainPassword) { return fmt.Errorf("password appears to be corrupted encrypted data") } @@ -144,7 +122,6 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u plainPassword = decrypted } - // Validate plain password if plainPassword == "" { return errors.New("decrypted password is empty") } @@ -153,13 +130,11 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u return errors.New("password too short after decryption") } - // Hash the plain password using bcrypt hashedPassword, err := password.HashPassword(plainPassword) if err != nil { return fmt.Errorf("failed to hash password: %v", err) } - // Update with hashed password if err := tx.Model(user).Update("password", hashedPassword).Error; err != nil { return fmt.Errorf("failed to update password: %v", err) } @@ -168,19 +143,16 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u return nil } -// UserForMigration represents a user record for migration purposes type UserForMigration struct { ID string `gorm:"column:id"` Username string `gorm:"column:username"` Password string `gorm:"column:password"` } -// TableName specifies the table name for GORM func (UserForMigration) TableName() string { return "users" } -// MigrationRecord tracks applied migrations type MigrationRecord struct { ID uint `gorm:"primaryKey"` MigrationName string `gorm:"unique;not null"` @@ -189,49 +161,38 @@ type MigrationRecord struct { Notes string } -// TableName specifies the table name for GORM func (MigrationRecord) TableName() string { return "migration_records" } -// isAlreadyHashed checks if a password is already bcrypt hashed func isAlreadyHashed(password string) bool { return len(password) >= 60 && (password[:4] == "$2a$" || password[:4] == "$2b$" || password[:4] == "$2y$") } -// containsBinaryData checks if a string contains binary data func containsBinaryData(s string) bool { for _, b := range []byte(s) { - if b < 32 && b != 9 && b != 10 && b != 13 { // Allow tab, newline, carriage return + if b < 32 && b != 9 && b != 10 && b != 13 { return true } } return false } -// isDuplicateColumnError checks if an error is due to duplicate column func isDuplicateColumnError(err error) bool { errStr := err.Error() return fmt.Sprintf("%v", errStr) == "duplicate column name: password_backup" || fmt.Sprintf("%v", errStr) == "SQLITE_ERROR: duplicate column name: password_backup" } -// decryptOldPassword attempts to decrypt using the old encryption method -// This is a simplified version of the old DecryptPassword function func decryptOldPassword(encryptedPassword string) (string, error) { - // This would use the old decryption logic - // For now, we'll return an error to force treating as plain text - // In a real scenario, you'd implement the old decryption here return "", errors.New("old decryption not implemented - treating as plain text") } -// Down reverses the migration (if needed) func (m *Migration001UpgradePasswordSecurity) Down() error { logging.Error("Password security migration rollback is not supported for security reasons") return errors.New("password security migration rollback is not supported") } -// RunMigration is a convenience function to run the migration func RunPasswordSecurityMigration(db *gorm.DB) error { migration := NewMigration001UpgradePasswordSecurity(db) return migration.Up() diff --git a/local/migrations/002_migrate_to_uuid.go b/local/migrations/002_migrate_to_uuid.go index 7cf915d..2387dbd 100644 --- a/local/migrations/002_migrate_to_uuid.go +++ b/local/migrations/002_migrate_to_uuid.go @@ -9,21 +9,17 @@ import ( "gorm.io/gorm" ) -// Migration002MigrateToUUID migrates tables from integer IDs to UUIDs type Migration002MigrateToUUID struct { DB *gorm.DB } -// NewMigration002MigrateToUUID creates a new UUID migration func NewMigration002MigrateToUUID(db *gorm.DB) *Migration002MigrateToUUID { return &Migration002MigrateToUUID{DB: db} } -// Up executes the migration func (m *Migration002MigrateToUUID) Up() error { logging.Info("Checking UUID migration...") - // Check if migration is needed by looking at the servers table structure if !m.needsMigration() { logging.Info("UUID migration not needed - tables already use UUID primary keys") return nil @@ -31,7 +27,6 @@ func (m *Migration002MigrateToUUID) Up() error { logging.Info("Starting UUID migration...") - // Check if migration has already been applied var migrationRecord MigrationRecord err := m.DB.Where("migration_name = ?", "002_migrate_to_uuid").First(&migrationRecord).Error if err == nil { @@ -39,12 +34,10 @@ func (m *Migration002MigrateToUUID) Up() error { return nil } - // Create migration tracking table if it doesn't exist if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil { return fmt.Errorf("failed to create migration tracking table: %v", err) } - // Execute the UUID migration using the existing migration function logging.Info("Executing UUID migration...") if err := runUUIDMigrationSQL(m.DB); err != nil { return fmt.Errorf("failed to execute UUID migration: %v", err) @@ -54,9 +47,7 @@ func (m *Migration002MigrateToUUID) Up() error { return nil } -// needsMigration checks if the UUID migration is needed by examining table structure func (m *Migration002MigrateToUUID) needsMigration() bool { - // Check if servers table exists and has integer primary key var result struct { Type string `gorm:"column:type"` } @@ -67,29 +58,22 @@ func (m *Migration002MigrateToUUID) needsMigration() bool { `).Scan(&result).Error if err != nil || result.Type == "" { - // Table doesn't exist or no primary key found - assume no migration needed return false } - // If the primary key is INTEGER, we need migration - // If it's TEXT (UUID), migration already done return result.Type == "INTEGER" || result.Type == "integer" } -// Down reverses the migration (not implemented for safety) func (m *Migration002MigrateToUUID) Down() error { logging.Error("UUID migration rollback is not supported for data safety reasons") return fmt.Errorf("UUID migration rollback is not supported") } -// runUUIDMigrationSQL executes the UUID migration using the SQL file func runUUIDMigrationSQL(db *gorm.DB) error { - // Disable foreign key constraints during migration if err := db.Exec("PRAGMA foreign_keys=OFF").Error; err != nil { return fmt.Errorf("failed to disable foreign keys: %v", err) } - // Start transaction tx := db.Begin() if tx.Error != nil { return fmt.Errorf("failed to start transaction: %v", tx.Error) @@ -101,25 +85,21 @@ func runUUIDMigrationSQL(db *gorm.DB) error { } }() - // Read the migration SQL from file sqlPath := filepath.Join("scripts", "migrations", "002_migrate_servers_to_uuid.sql") migrationSQL, err := ioutil.ReadFile(sqlPath) if err != nil { return fmt.Errorf("failed to read migration SQL file: %v", err) } - // Execute the migration if err := tx.Exec(string(migrationSQL)).Error; err != nil { tx.Rollback() return fmt.Errorf("failed to execute migration: %v", err) } - // Commit transaction if err := tx.Commit().Error; err != nil { return fmt.Errorf("failed to commit migration: %v", err) } - // Re-enable foreign key constraints if err := db.Exec("PRAGMA foreign_keys=ON").Error; err != nil { return fmt.Errorf("failed to re-enable foreign keys: %v", err) } @@ -127,7 +107,6 @@ func runUUIDMigrationSQL(db *gorm.DB) error { return nil } -// RunUUIDMigration is a convenience function to run the migration func RunUUIDMigration(db *gorm.DB) error { migration := NewMigration002MigrateToUUID(db) return migration.Up() diff --git a/local/migrations/003_update_state_history_sessions.go b/local/migrations/003_update_state_history_sessions.go index 0092c22..3fd53bc 100644 --- a/local/migrations/003_update_state_history_sessions.go +++ b/local/migrations/003_update_state_history_sessions.go @@ -7,21 +7,17 @@ import ( "gorm.io/gorm" ) -// UpdateStateHistorySessions migrates tables from integer IDs to UUIDs type UpdateStateHistorySessions struct { DB *gorm.DB } -// NewUpdateStateHistorySessions creates a new UUID migration func NewUpdateStateHistorySessions(db *gorm.DB) *UpdateStateHistorySessions { return &UpdateStateHistorySessions{DB: db} } -// Up executes the migration func (m *UpdateStateHistorySessions) Up() error { logging.Info("Checking UUID migration...") - // Check if migration is needed by looking at the servers table structure if !m.needsMigration() { logging.Info("UUID migration not needed - tables already use UUID primary keys") return nil @@ -29,7 +25,6 @@ func (m *UpdateStateHistorySessions) Up() error { logging.Info("Starting UUID migration...") - // Check if migration has already been applied var migrationRecord MigrationRecord err := m.DB.Where("migration_name = ?", "002_migrate_to_uuid").First(&migrationRecord).Error if err == nil { @@ -37,12 +32,10 @@ func (m *UpdateStateHistorySessions) Up() error { return nil } - // Create migration tracking table if it doesn't exist if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil { return fmt.Errorf("failed to create migration tracking table: %v", err) } - // Execute the UUID migration using the existing migration function logging.Info("Executing UUID migration...") if err := runUUIDMigrationSQL(m.DB); err != nil { return fmt.Errorf("failed to execute UUID migration: %v", err) @@ -52,9 +45,7 @@ func (m *UpdateStateHistorySessions) Up() error { return nil } -// needsMigration checks if the UUID migration is needed by examining table structure func (m *UpdateStateHistorySessions) needsMigration() bool { - // Check if servers table exists and has integer primary key var result struct { Exists bool `gorm:"column:exists"` } @@ -65,26 +56,21 @@ func (m *UpdateStateHistorySessions) needsMigration() bool { `).Scan(&result).Error if err != nil || !result.Exists { - // Table doesn't exist or no primary key found - assume no migration needed return false } return result.Exists } -// Down reverses the migration (not implemented for safety) func (m *UpdateStateHistorySessions) Down() error { logging.Error("UUID migration rollback is not supported for data safety reasons") return fmt.Errorf("UUID migration rollback is not supported") } -// runUpdateStateHistorySessionsMigration executes the UUID migration using the SQL file func runUpdateStateHistorySessionsMigration(db *gorm.DB) error { - // Disable foreign key constraints during migration if err := db.Exec("PRAGMA foreign_keys=OFF").Error; err != nil { return fmt.Errorf("failed to disable foreign keys: %v", err) } - // Start transaction tx := db.Begin() if tx.Error != nil { return fmt.Errorf("failed to start transaction: %v", tx.Error) @@ -98,18 +84,15 @@ func runUpdateStateHistorySessionsMigration(db *gorm.DB) error { migrationSQL := "UPDATE state_history SET session = upper(substr(session, 1, 1));" - // Execute the migration if err := tx.Exec(string(migrationSQL)).Error; err != nil { tx.Rollback() return fmt.Errorf("failed to execute migration: %v", err) } - // Commit transaction if err := tx.Commit().Error; err != nil { return fmt.Errorf("failed to commit migration: %v", err) } - // Re-enable foreign key constraints if err := db.Exec("PRAGMA foreign_keys=ON").Error; err != nil { return fmt.Errorf("failed to re-enable foreign keys: %v", err) } @@ -117,7 +100,6 @@ func runUpdateStateHistorySessionsMigration(db *gorm.DB) error { return nil } -// RunUpdateStateHistorySessionsMigration is a convenience function to run the migration func RunUpdateStateHistorySessionsMigration(db *gorm.DB) error { migration := NewUpdateStateHistorySessions(db) return migration.Up() diff --git a/local/model/cache.go b/local/model/cache.go index 5b1a1eb..989d57e 100644 --- a/local/model/cache.go +++ b/local/model/cache.go @@ -6,20 +6,17 @@ import ( "time" ) -// StatusCache represents a cached server status with expiration type StatusCache struct { Status ServiceStatus UpdatedAt time.Time } -// CacheConfig holds configuration for cache behavior type CacheConfig struct { - ExpirationTime time.Duration // How long before a cache entry expires - ThrottleTime time.Duration // Minimum time between status checks - DefaultStatus ServiceStatus // Default status to return when throttled + ExpirationTime time.Duration + ThrottleTime time.Duration + DefaultStatus ServiceStatus } -// ServerStatusCache manages cached server statuses type ServerStatusCache struct { sync.RWMutex cache map[string]*StatusCache @@ -27,7 +24,6 @@ type ServerStatusCache struct { lastChecked map[string]time.Time } -// NewServerStatusCache creates a new server status cache func NewServerStatusCache(config CacheConfig) *ServerStatusCache { return &ServerStatusCache{ cache: make(map[string]*StatusCache), @@ -36,12 +32,10 @@ func NewServerStatusCache(config CacheConfig) *ServerStatusCache { } } -// GetStatus retrieves the cached status or indicates if a fresh check is needed func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool) { c.RLock() defer c.RUnlock() - // Check if we're being throttled if lastCheck, exists := c.lastChecked[serviceName]; exists { if time.Since(lastCheck) < c.config.ThrottleTime { if cached, ok := c.cache[serviceName]; ok { @@ -51,7 +45,6 @@ func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool) } } - // Check if we have a valid cached entry if cached, ok := c.cache[serviceName]; ok { if time.Since(cached.UpdatedAt) < c.config.ExpirationTime { return cached.Status, false @@ -61,7 +54,6 @@ func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool) return StatusUnknown, true } -// UpdateStatus updates the cache with a new status func (c *ServerStatusCache) UpdateStatus(serviceName string, status ServiceStatus) { c.Lock() defer c.Unlock() @@ -73,7 +65,6 @@ func (c *ServerStatusCache) UpdateStatus(serviceName string, status ServiceStatu c.lastChecked[serviceName] = time.Now() } -// InvalidateStatus removes a specific service from the cache func (c *ServerStatusCache) InvalidateStatus(serviceName string) { c.Lock() defer c.Unlock() @@ -82,7 +73,6 @@ func (c *ServerStatusCache) InvalidateStatus(serviceName string) { delete(c.lastChecked, serviceName) } -// Clear removes all entries from the cache func (c *ServerStatusCache) Clear() { c.Lock() defer c.Unlock() @@ -91,13 +81,11 @@ func (c *ServerStatusCache) Clear() { c.lastChecked = make(map[string]time.Time) } -// LookupCache provides a generic cache for lookup data type LookupCache struct { sync.RWMutex data map[string]interface{} } -// NewLookupCache creates a new lookup cache func NewLookupCache() *LookupCache { logging.Debug("Initializing new LookupCache") return &LookupCache{ @@ -105,7 +93,6 @@ func NewLookupCache() *LookupCache { } } -// Get retrieves a cached value by key func (c *LookupCache) Get(key string) (interface{}, bool) { c.RLock() defer c.RUnlock() @@ -119,7 +106,6 @@ func (c *LookupCache) Get(key string) (interface{}, bool) { return value, exists } -// Set stores a value in the cache func (c *LookupCache) Set(key string, value interface{}) { c.Lock() defer c.Unlock() @@ -128,7 +114,6 @@ func (c *LookupCache) Set(key string, value interface{}) { logging.Debug("Cache SET for key: %s", key) } -// Clear removes all entries from the cache func (c *LookupCache) Clear() { c.Lock() defer c.Unlock() @@ -137,13 +122,11 @@ func (c *LookupCache) Clear() { logging.Debug("Cache CLEARED") } -// ConfigEntry represents a cached configuration entry with its update time type ConfigEntry[T any] struct { Data T UpdatedAt time.Time } -// getConfigFromCache is a generic helper function to retrieve cached configs func getConfigFromCache[T any](cache map[string]*ConfigEntry[T], serverID string, expirationTime time.Duration) (*T, bool) { if entry, ok := cache[serverID]; ok { if time.Since(entry.UpdatedAt) < expirationTime { @@ -157,7 +140,6 @@ func getConfigFromCache[T any](cache map[string]*ConfigEntry[T], serverID string return nil, false } -// updateConfigInCache is a generic helper function to update cached configs func updateConfigInCache[T any](cache map[string]*ConfigEntry[T], serverID string, data T) { cache[serverID] = &ConfigEntry[T]{ Data: data, @@ -166,7 +148,6 @@ func updateConfigInCache[T any](cache map[string]*ConfigEntry[T], serverID strin logging.Debug("Config cache SET for server ID: %s", serverID) } -// ServerConfigCache manages cached server configurations type ServerConfigCache struct { sync.RWMutex configuration map[string]*ConfigEntry[Configuration] @@ -177,7 +158,6 @@ type ServerConfigCache struct { config CacheConfig } -// NewServerConfigCache creates a new server configuration cache func NewServerConfigCache(config CacheConfig) *ServerConfigCache { logging.Debug("Initializing new ServerConfigCache with expiration time: %v, throttle time: %v", config.ExpirationTime, config.ThrottleTime) return &ServerConfigCache{ @@ -190,7 +170,6 @@ func NewServerConfigCache(config CacheConfig) *ServerConfigCache { } } -// GetConfiguration retrieves a cached configuration func (c *ServerConfigCache) GetConfiguration(serverID string) (*Configuration, bool) { c.RLock() defer c.RUnlock() @@ -198,7 +177,6 @@ func (c *ServerConfigCache) GetConfiguration(serverID string) (*Configuration, b return getConfigFromCache(c.configuration, serverID, c.config.ExpirationTime) } -// GetAssistRules retrieves cached assist rules func (c *ServerConfigCache) GetAssistRules(serverID string) (*AssistRules, bool) { c.RLock() defer c.RUnlock() @@ -206,7 +184,6 @@ func (c *ServerConfigCache) GetAssistRules(serverID string) (*AssistRules, bool) return getConfigFromCache(c.assistRules, serverID, c.config.ExpirationTime) } -// GetEvent retrieves cached event configuration func (c *ServerConfigCache) GetEvent(serverID string) (*EventConfig, bool) { c.RLock() defer c.RUnlock() @@ -214,7 +191,6 @@ func (c *ServerConfigCache) GetEvent(serverID string) (*EventConfig, bool) { return getConfigFromCache(c.event, serverID, c.config.ExpirationTime) } -// GetEventRules retrieves cached event rules func (c *ServerConfigCache) GetEventRules(serverID string) (*EventRules, bool) { c.RLock() defer c.RUnlock() @@ -222,7 +198,6 @@ func (c *ServerConfigCache) GetEventRules(serverID string) (*EventRules, bool) { return getConfigFromCache(c.eventRules, serverID, c.config.ExpirationTime) } -// GetSettings retrieves cached server settings func (c *ServerConfigCache) GetSettings(serverID string) (*ServerSettings, bool) { c.RLock() defer c.RUnlock() @@ -230,7 +205,6 @@ func (c *ServerConfigCache) GetSettings(serverID string) (*ServerSettings, bool) return getConfigFromCache(c.settings, serverID, c.config.ExpirationTime) } -// UpdateConfiguration updates the configuration cache func (c *ServerConfigCache) UpdateConfiguration(serverID string, config Configuration) { c.Lock() defer c.Unlock() @@ -238,7 +212,6 @@ func (c *ServerConfigCache) UpdateConfiguration(serverID string, config Configur updateConfigInCache(c.configuration, serverID, config) } -// UpdateAssistRules updates the assist rules cache func (c *ServerConfigCache) UpdateAssistRules(serverID string, rules AssistRules) { c.Lock() defer c.Unlock() @@ -246,7 +219,6 @@ func (c *ServerConfigCache) UpdateAssistRules(serverID string, rules AssistRules updateConfigInCache(c.assistRules, serverID, rules) } -// UpdateEvent updates the event configuration cache func (c *ServerConfigCache) UpdateEvent(serverID string, event EventConfig) { c.Lock() defer c.Unlock() @@ -254,7 +226,6 @@ func (c *ServerConfigCache) UpdateEvent(serverID string, event EventConfig) { updateConfigInCache(c.event, serverID, event) } -// UpdateEventRules updates the event rules cache func (c *ServerConfigCache) UpdateEventRules(serverID string, rules EventRules) { c.Lock() defer c.Unlock() @@ -262,7 +233,6 @@ func (c *ServerConfigCache) UpdateEventRules(serverID string, rules EventRules) updateConfigInCache(c.eventRules, serverID, rules) } -// UpdateSettings updates the server settings cache func (c *ServerConfigCache) UpdateSettings(serverID string, settings ServerSettings) { c.Lock() defer c.Unlock() @@ -270,7 +240,6 @@ func (c *ServerConfigCache) UpdateSettings(serverID string, settings ServerSetti updateConfigInCache(c.settings, serverID, settings) } -// InvalidateServerCache removes all cached configurations for a specific server func (c *ServerConfigCache) InvalidateServerCache(serverID string) { c.Lock() defer c.Unlock() @@ -283,7 +252,6 @@ func (c *ServerConfigCache) InvalidateServerCache(serverID string) { delete(c.settings, serverID) } -// Clear removes all entries from the cache func (c *ServerConfigCache) Clear() { c.Lock() defer c.Unlock() diff --git a/local/model/config.go b/local/model/config.go index 3833e78..8ea2306 100644 --- a/local/model/config.go +++ b/local/model/config.go @@ -13,17 +13,15 @@ import ( type IntString int type IntBool int -// Config tracks configuration modifications type Config struct { ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"` ServerID uuid.UUID `json:"serverId" gorm:"not null;type:uuid"` - ConfigFile string `json:"configFile" gorm:"not null"` // e.g. "settings.json" + ConfigFile string `json:"configFile" gorm:"not null"` OldConfig string `json:"oldConfig" gorm:"type:text"` NewConfig string `json:"newConfig" gorm:"type:text"` ChangedAt time.Time `json:"changedAt" gorm:"default:CURRENT_TIMESTAMP"` } -// BeforeCreate is a GORM hook that runs before creating new config entries func (c *Config) BeforeCreate(tx *gorm.DB) error { if c.ID == uuid.Nil { c.ID = uuid.New() @@ -121,8 +119,6 @@ type Configuration struct { ConfigVersion IntString `json:"configVersion"` } -// Known configuration keys - func (i *IntBool) UnmarshalJSON(b []byte) error { var str int if err := json.Unmarshal(b, &str); err == nil && str <= 1 { diff --git a/local/model/filter.go b/local/model/filter.go index 13ef18c..b9df064 100644 --- a/local/model/filter.go +++ b/local/model/filter.go @@ -7,7 +7,6 @@ import ( "gorm.io/gorm" ) -// BaseFilter contains common filter fields that can be embedded in other filters type BaseFilter struct { Page int `query:"page"` PageSize int `query:"page_size"` @@ -15,18 +14,15 @@ type BaseFilter struct { SortDesc bool `query:"sort_desc"` } -// DateRangeFilter adds date range filtering capabilities type DateRangeFilter struct { StartDate time.Time `query:"start_date" time_format:"2006-01-02T15:04:05Z07:00"` EndDate time.Time `query:"end_date" time_format:"2006-01-02T15:04:05Z07:00"` } -// ServerBasedFilter adds server ID filtering capability type ServerBasedFilter struct { ServerID string `param:"id"` } -// ConfigFilter defines filtering options for Config queries type ConfigFilter struct { BaseFilter ServerBasedFilter @@ -34,13 +30,11 @@ type ConfigFilter struct { ChangedAt time.Time `query:"changed_at" time_format:"2006-01-02T15:04:05Z07:00"` } -// ApiFilter defines filtering options for Api queries type ServiceControlFilter struct { BaseFilter ServiceControl string `query:"serviceControl"` } -// MembershipFilter defines filtering options for User queries type MembershipFilter struct { BaseFilter Username string `query:"username"` @@ -48,36 +42,32 @@ type MembershipFilter struct { RoleID string `query:"role_id"` } -// Pagination returns the offset and limit for database queries func (f *BaseFilter) Pagination() (offset, limit int) { if f.Page < 1 { f.Page = 1 } if f.PageSize < 1 { - f.PageSize = 10 // Default page size + f.PageSize = 10 } offset = (f.Page - 1) * f.PageSize limit = f.PageSize return } -// GetSorting returns the sort field and direction for database queries func (f *BaseFilter) GetSorting() (field string, desc bool) { if f.SortBy == "" { - return "id", false // Default sorting + return "id", false } return f.SortBy, f.SortDesc } -// IsDateRangeValid checks if both dates are set and start date is before end date func (f *DateRangeFilter) IsDateRangeValid() bool { if f.StartDate.IsZero() || f.EndDate.IsZero() { - return true // If either date is not set, consider it valid + return true } return f.StartDate.Before(f.EndDate) } -// ApplyFilter applies the membership filter to a GORM query func (f *MembershipFilter) ApplyFilter(query *gorm.DB) *gorm.DB { if f.Username != "" { query = query.Where("username LIKE ?", "%"+f.Username+"%") @@ -93,12 +83,10 @@ func (f *MembershipFilter) ApplyFilter(query *gorm.DB) *gorm.DB { return query } -// Pagination returns the offset and limit for database queries func (f *MembershipFilter) Pagination() (offset, limit int) { return f.BaseFilter.Pagination() } -// GetSorting returns the sort field and direction for database queries func (f *MembershipFilter) GetSorting() (field string, desc bool) { return f.BaseFilter.GetSorting() } diff --git a/local/model/lookup.go b/local/model/lookup.go index 4f303cc..929a4f1 100644 --- a/local/model/lookup.go +++ b/local/model/lookup.go @@ -1,31 +1,26 @@ package model -// Track represents a track and its capacity type Track struct { Name string `json:"track" gorm:"primaryKey;size:50"` UniquePitBoxes int `json:"unique_pit_boxes"` PrivateServerSlots int `json:"private_server_slots"` } -// CarModel represents a car model mapping type CarModel struct { Value int `json:"value" gorm:"primaryKey"` CarModel string `json:"car_model"` } -// DriverCategory represents driver skill categories type DriverCategory struct { Value int `json:"value" gorm:"primaryKey"` Category string `json:"category"` } -// CupCategory represents championship cup categories type CupCategory struct { Value int `json:"value" gorm:"primaryKey"` Category string `json:"category"` } -// SessionType represents session types type SessionType struct { Value int `json:"value" gorm:"primaryKey"` SessionType string `json:"session_type"` diff --git a/local/model/model.go b/local/model/model.go index 3aadae4..8b13fd4 100644 --- a/local/model/model.go +++ b/local/model/model.go @@ -32,8 +32,6 @@ type BaseModel struct { DateUpdated time.Time `json:"dateUpdated"` } -// Init -// Initializes base model with DateCreated, DateUpdated, and Id values. func (cm *BaseModel) Init() { date := time.Now() cm.Id = uuid.NewString() diff --git a/local/model/permission.go b/local/model/permission.go index 3e3022d..88240f9 100644 --- a/local/model/permission.go +++ b/local/model/permission.go @@ -5,15 +5,13 @@ import ( "gorm.io/gorm" ) -// Permission represents an action that can be performed in the system. type Permission struct { ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"` Name string `json:"name" gorm:"unique_index;not null"` } -// BeforeCreate is a GORM hook that runs before creating new credentials func (s *Permission) BeforeCreate(tx *gorm.DB) error { s.ID = uuid.New() return nil -} \ No newline at end of file +} diff --git a/local/model/permissions.go b/local/model/permissions.go index b96eff0..d57c097 100644 --- a/local/model/permissions.go +++ b/local/model/permissions.go @@ -1,6 +1,5 @@ package model -// Permission constants const ( ServerView = "server.view" ServerCreate = "server.create" @@ -27,7 +26,6 @@ const ( MembershipEdit = "membership.edit" ) -// AllPermissions returns a slice of all permission strings. func AllPermissions() []string { return []string{ ServerView, diff --git a/local/model/role.go b/local/model/role.go index 9d48842..836393c 100644 --- a/local/model/role.go +++ b/local/model/role.go @@ -5,16 +5,14 @@ import ( "gorm.io/gorm" ) -// Role represents a user role in the system. type Role struct { ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"` Name string `json:"name" gorm:"unique_index;not null"` Permissions []Permission `json:"permissions" gorm:"many2many:role_permissions;"` } -// BeforeCreate is a GORM hook that runs before creating new credentials func (s *Role) BeforeCreate(tx *gorm.DB) error { s.ID = uuid.New() return nil -} \ No newline at end of file +} diff --git a/local/model/server.go b/local/model/server.go index 95a9e8c..bab75d6 100644 --- a/local/model/server.go +++ b/local/model/server.go @@ -16,7 +16,6 @@ const ( ServiceNamePrefix = "ACC-Server" ) -// Server represents an ACC server instance type ServerAPI struct { Name string `json:"name"` Status ServiceStatus `json:"status"` @@ -35,28 +34,27 @@ func (s *Server) ToServerAPI() *ServerAPI { } } -// Server represents an ACC server instance type Server struct { ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"` Name string `gorm:"not null" json:"name"` Status ServiceStatus `json:"status" gorm:"-"` IP string `gorm:"not null" json:"-"` Port int `gorm:"not null" json:"-"` - Path string `gorm:"not null" json:"path"` // e.g. "/acc/servers/server1/" - ServiceName string `gorm:"not null" json:"serviceName"` // Windows service name + Path string `gorm:"not null" json:"path"` + ServiceName string `gorm:"not null" json:"serviceName"` State *ServerState `gorm:"-" json:"state"` DateCreated time.Time `json:"dateCreated"` FromSteamCMD bool `gorm:"not null; default:true" json:"-"` } type PlayerState struct { - CarID int // Car ID in broadcast packets - DriverName string // Optional: pulled from registration packet + CarID int + DriverName string TeamName string CarModel string CurrentLap int - LastLapTime int // in milliseconds - BestLapTime int // in milliseconds + LastLapTime int + BestLapTime int Position int ConnectedAt time.Time DisconnectedAt *time.Time @@ -67,8 +65,6 @@ type State struct { Session string `json:"session"` SessionStart time.Time `json:"sessionStart"` PlayerCount int `json:"playerCount"` - // Players map[int]*PlayerState - // etc. } type ServerState struct { @@ -79,11 +75,8 @@ type ServerState struct { Track string `json:"track"` MaxConnections int `json:"maxConnections"` SessionDurationMinutes int `json:"sessionDurationMinutes"` - // Players map[int]*PlayerState - // etc. } -// ServerFilter defines filtering options for Server queries type ServerFilter struct { BaseFilter ServerBasedFilter @@ -92,9 +85,7 @@ type ServerFilter struct { Status string `query:"status"` } -// ApplyFilter implements the Filterable interface func (f *ServerFilter) ApplyFilter(query *gorm.DB) *gorm.DB { - // Apply server filter if f.ServerID != "" { if serverUUID, err := uuid.Parse(f.ServerID); err == nil { query = query.Where("id = ?", serverUUID) @@ -110,16 +101,13 @@ func (s *Server) GenerateUUID() { } } -// BeforeCreate is a GORM hook that runs before creating a new server func (s *Server) BeforeCreate(tx *gorm.DB) error { if s.Name == "" { return errors.New("server name is required") } - // Generate UUID if not set s.GenerateUUID() - // Generate service name and config path if not set if s.ServiceName == "" { s.ServiceName = s.GenerateServiceName() } @@ -127,7 +115,6 @@ func (s *Server) BeforeCreate(tx *gorm.DB) error { s.Path = s.GenerateServerPath(BaseServerPath) } - // Set creation date if not set if s.DateCreated.IsZero() { s.DateCreated = time.Now().UTC() } @@ -135,19 +122,14 @@ func (s *Server) BeforeCreate(tx *gorm.DB) error { return nil } -// GenerateServiceName creates a unique service name based on the server name func (s *Server) GenerateServiceName() string { - // If ID is set, use it if s.ID != uuid.Nil { return fmt.Sprintf("%s-%s", ServiceNamePrefix, s.ID.String()[:8]) } - // Otherwise use a timestamp-based unique identifier return fmt.Sprintf("%s-%d", ServiceNamePrefix, time.Now().UnixNano()) } -// GenerateServerPath creates the config path based on the service name func (s *Server) GenerateServerPath(steamCMDPath string) string { - // Ensure service name is set if s.ServiceName == "" { s.ServiceName = s.GenerateServiceName() } diff --git a/local/model/service_control.go b/local/model/service_control.go index c447294..835db6a 100644 --- a/local/model/service_control.go +++ b/local/model/service_control.go @@ -17,7 +17,6 @@ const ( StatusRunning ) -// String converts the ServiceStatus to its string representation func (s ServiceStatus) String() string { switch s { case StatusRunning: @@ -35,7 +34,6 @@ func (s ServiceStatus) String() string { } } -// ParseServiceStatus converts a string to ServiceStatus func ParseServiceStatus(s string) ServiceStatus { switch s { case "SERVICE_RUNNING": @@ -53,31 +51,24 @@ func ParseServiceStatus(s string) ServiceStatus { } } -// MarshalJSON implements json.Marshaler interface func (s ServiceStatus) MarshalJSON() ([]byte, error) { - // Return the numeric value instead of string return []byte(strconv.Itoa(int(s))), nil } -// UnmarshalJSON implements json.Unmarshaler interface func (s *ServiceStatus) UnmarshalJSON(data []byte) error { - // Try to parse as number first if i, err := strconv.Atoi(string(data)); err == nil { *s = ServiceStatus(i) return nil } - // Fallback to string parsing for backward compatibility str := string(data) if len(str) >= 2 { - // Remove quotes if present str = str[1 : len(str)-1] } *s = ParseServiceStatus(str) return nil } -// Scan implements the sql.Scanner interface func (s *ServiceStatus) Scan(value interface{}) error { if value == nil { *s = StatusUnknown @@ -99,7 +90,6 @@ func (s *ServiceStatus) Scan(value interface{}) error { } } -// Value implements the driver.Valuer interface func (s ServiceStatus) Value() (driver.Value, error) { return s.String(), nil } diff --git a/local/model/state_history.go b/local/model/state_history.go index 1579e32..c47bb17 100644 --- a/local/model/state_history.go +++ b/local/model/state_history.go @@ -10,27 +10,22 @@ import ( "gorm.io/gorm" ) -// StateHistoryFilter combines common filter capabilities type StateHistoryFilter struct { - ServerBasedFilter // Adds server ID from path parameter - DateRangeFilter // Adds date range filtering + ServerBasedFilter + DateRangeFilter - // Additional fields specific to state history Session TrackSession `query:"session"` MinPlayers *int `query:"min_players"` MaxPlayers *int `query:"max_players"` } -// ApplyFilter implements the Filterable interface func (f *StateHistoryFilter) ApplyFilter(query *gorm.DB) *gorm.DB { - // Apply server filter if f.ServerID != "" { if serverUUID, err := uuid.Parse(f.ServerID); err == nil { query = query.Where("server_id = ?", serverUUID) } } - // Apply date range filter if set timeZero := time.Time{} if f.StartDate != timeZero { query = query.Where("date_created >= ?", f.StartDate) @@ -39,12 +34,10 @@ func (f *StateHistoryFilter) ApplyFilter(query *gorm.DB) *gorm.DB { query = query.Where("date_created <= ?", f.EndDate) } - // Apply session filter if set if f.Session != "" { query = query.Where("session = ?", f.Session) } - // Apply player count filters if set if f.MinPlayers != nil { query = query.Where("player_count >= ?", *f.MinPlayers) } @@ -114,10 +107,9 @@ type StateHistory struct { DateCreated time.Time `json:"dateCreated"` SessionStart time.Time `json:"sessionStart"` SessionDurationMinutes int `json:"sessionDurationMinutes"` - SessionID uuid.UUID `json:"sessionId" gorm:"not null;type:uuid"` // Unique identifier for each session/event + SessionID uuid.UUID `json:"sessionId" gorm:"not null;type:uuid"` } -// BeforeCreate is a GORM hook that runs before creating new state history entries func (sh *StateHistory) BeforeCreate(tx *gorm.DB) error { if sh.ID == uuid.Nil { sh.ID = uuid.New() diff --git a/local/model/state_history_stats.go b/local/model/state_history_stats.go index 0a8e853..15c4fe3 100644 --- a/local/model/state_history_stats.go +++ b/local/model/state_history_stats.go @@ -21,7 +21,7 @@ type StateHistoryStats struct { AveragePlayers float64 `json:"averagePlayers"` PeakPlayers int `json:"peakPlayers"` TotalSessions int `json:"totalSessions"` - TotalPlaytime int `json:"totalPlaytime" gorm:"-"` // in minutes + TotalPlaytime int `json:"totalPlaytime" gorm:"-"` PlayerCountOverTime []PlayerCountPoint `json:"playerCountOverTime" gorm:"-"` SessionTypes []SessionCount `json:"sessionTypes" gorm:"-"` DailyActivity []DailyActivity `json:"dailyActivity" gorm:"-"` diff --git a/local/model/steam_credentials.go b/local/model/steam_credentials.go index 263a5d6..075d0e6 100644 --- a/local/model/steam_credentials.go +++ b/local/model/steam_credentials.go @@ -16,21 +16,18 @@ import ( "gorm.io/gorm" ) -// SteamCredentials represents stored Steam login credentials type SteamCredentials struct { ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"` Username string `gorm:"not null" json:"username"` - Password string `gorm:"not null" json:"-"` // Encrypted, not exposed in JSON + Password string `gorm:"not null" json:"-"` DateCreated time.Time `json:"dateCreated"` LastUpdated time.Time `json:"lastUpdated"` } -// TableName specifies the table name for GORM func (SteamCredentials) TableName() string { return "steam_credentials" } -// BeforeCreate is a GORM hook that runs before creating new credentials func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error { if s.ID == uuid.Nil { s.ID = uuid.New() @@ -42,7 +39,6 @@ func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error { } s.LastUpdated = now - // Encrypt password before saving encrypted, err := EncryptPassword(s.Password) if err != nil { return err @@ -52,11 +48,9 @@ func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error { return nil } -// BeforeUpdate is a GORM hook that runs before updating credentials func (s *SteamCredentials) BeforeUpdate(tx *gorm.DB) error { s.LastUpdated = time.Now().UTC() - // Only encrypt if password field is being updated if tx.Statement.Changed("Password") { encrypted, err := EncryptPassword(s.Password) if err != nil { @@ -68,9 +62,7 @@ func (s *SteamCredentials) BeforeUpdate(tx *gorm.DB) error { return nil } -// AfterFind is a GORM hook that runs after fetching credentials func (s *SteamCredentials) AfterFind(tx *gorm.DB) error { - // Decrypt password after fetching if s.Password != "" { decrypted, err := DecryptPassword(s.Password) if err != nil { @@ -81,18 +73,15 @@ func (s *SteamCredentials) AfterFind(tx *gorm.DB) error { return nil } -// Validate checks if the credentials are valid with enhanced security checks func (s *SteamCredentials) Validate() error { if s.Username == "" { return errors.New("username is required") } - // Enhanced username validation if len(s.Username) < 3 || len(s.Username) > 64 { return errors.New("username must be between 3 and 64 characters") } - // Check for valid characters in username (alphanumeric, underscore, hyphen) if matched, _ := regexp.MatchString(`^[a-zA-Z0-9_-]+$`, s.Username); !matched { return errors.New("username contains invalid characters") } @@ -101,7 +90,6 @@ func (s *SteamCredentials) Validate() error { return errors.New("password is required") } - // Basic password validation if len(s.Password) < 6 { return errors.New("password must be at least 6 characters long") } @@ -110,7 +98,6 @@ func (s *SteamCredentials) Validate() error { return errors.New("password is too long") } - // Check for obvious weak passwords weakPasswords := []string{"password", "123456", "steam", "admin", "user"} lowerPass := strings.ToLower(s.Password) for _, weak := range weakPasswords { @@ -122,8 +109,6 @@ func (s *SteamCredentials) Validate() error { return nil } -// GetEncryptionKey returns the encryption key from config. -// The key is loaded from the ENCRYPTION_KEY environment variable. func GetEncryptionKey() []byte { key := []byte(configs.EncryptionKey) if len(key) != 32 { @@ -132,7 +117,6 @@ func GetEncryptionKey() []byte { return key } -// EncryptPassword encrypts a password using AES-256-GCM with enhanced security func EncryptPassword(password string) (string, error) { if password == "" { return "", errors.New("password cannot be empty") @@ -148,33 +132,27 @@ func EncryptPassword(password string) (string, error) { return "", err } - // Create a new GCM cipher gcm, err := cipher.NewGCM(block) if err != nil { return "", err } - // Create a cryptographically secure nonce nonce := make([]byte, gcm.NonceSize()) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return "", err } - // Encrypt the password with authenticated encryption ciphertext := gcm.Seal(nonce, nonce, []byte(password), nil) - // Return base64 encoded encrypted password return base64.StdEncoding.EncodeToString(ciphertext), nil } -// DecryptPassword decrypts an encrypted password with enhanced validation func DecryptPassword(encryptedPassword string) (string, error) { if encryptedPassword == "" { return "", errors.New("encrypted password cannot be empty") } - // Validate base64 format - if len(encryptedPassword) < 24 { // Minimum reasonable length + if len(encryptedPassword) < 24 { return "", errors.New("invalid encrypted password format") } @@ -184,13 +162,11 @@ func DecryptPassword(encryptedPassword string) (string, error) { return "", err } - // Create a new GCM cipher gcm, err := cipher.NewGCM(block) if err != nil { return "", err } - // Decode base64 encoded password ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword) if err != nil { return "", errors.New("invalid base64 encoding") @@ -207,7 +183,6 @@ func DecryptPassword(encryptedPassword string) (string, error) { return "", errors.New("decryption failed - invalid ciphertext or key") } - // Validate decrypted content decrypted := string(plaintext) if len(decrypted) == 0 || len(decrypted) > 1024 { return "", errors.New("invalid decrypted password") diff --git a/local/model/user.go b/local/model/user.go index 7cc8adb..959e340 100644 --- a/local/model/user.go +++ b/local/model/user.go @@ -8,7 +8,6 @@ import ( "gorm.io/gorm" ) -// User represents a user account in the system. type User struct { ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"` Username string `json:"username" gorm:"unique_index;not null"` @@ -17,16 +16,13 @@ type User struct { Role Role `json:"role"` } -// BeforeCreate is a GORM hook that runs before creating new users func (s *User) BeforeCreate(tx *gorm.DB) error { s.ID = uuid.New() - // Validate password strength if err := password.ValidatePasswordStrength(s.Password); err != nil { return err } - // Hash password before saving hashed, err := password.HashPassword(s.Password) if err != nil { return err @@ -36,11 +32,8 @@ func (s *User) BeforeCreate(tx *gorm.DB) error { return nil } -// BeforeUpdate is a GORM hook that runs before updating users func (s *User) BeforeUpdate(tx *gorm.DB) error { - // Only hash if password field is being updated if tx.Statement.Changed("Password") { - // Validate password strength if err := password.ValidatePasswordStrength(s.Password); err != nil { return err } @@ -55,14 +48,10 @@ func (s *User) BeforeUpdate(tx *gorm.DB) error { return nil } -// AfterFind is a GORM hook that runs after fetching users func (s *User) AfterFind(tx *gorm.DB) error { - // Password remains hashed - never decrypt - // This hook is kept for potential future use return nil } -// Validate checks if the user data is valid func (s *User) Validate() error { if s.Username == "" { return errors.New("username is required") @@ -73,7 +62,6 @@ func (s *User) Validate() error { return nil } -// VerifyPassword verifies a plain text password against the stored hash func (s *User) VerifyPassword(plainPassword string) error { return password.VerifyPassword(s.Password, plainPassword) } diff --git a/local/model/websocket.go b/local/model/websocket.go index 3fd6b20..f894a9e 100644 --- a/local/model/websocket.go +++ b/local/model/websocket.go @@ -4,7 +4,6 @@ import ( "github.com/google/uuid" ) -// ServerCreationStep represents the steps in server creation process type ServerCreationStep string const ( @@ -18,7 +17,6 @@ const ( StepCompleted ServerCreationStep = "completed" ) -// StepStatus represents the status of a step type StepStatus string const ( @@ -28,7 +26,6 @@ const ( StatusFailed StepStatus = "failed" ) -// WebSocketMessageType represents different types of WebSocket messages type WebSocketMessageType string const ( @@ -38,7 +35,6 @@ const ( MessageTypeComplete WebSocketMessageType = "complete" ) -// WebSocketMessage is the base structure for all WebSocket messages type WebSocketMessage struct { Type WebSocketMessageType `json:"type"` ServerID *uuid.UUID `json:"server_id,omitempty"` @@ -46,7 +42,6 @@ type WebSocketMessage struct { Data interface{} `json:"data"` } -// StepMessage represents a step update message type StepMessage struct { Step ServerCreationStep `json:"step"` Status StepStatus `json:"status"` @@ -54,26 +49,22 @@ type StepMessage struct { Error string `json:"error,omitempty"` } -// SteamOutputMessage represents SteamCMD output type SteamOutputMessage struct { Output string `json:"output"` IsError bool `json:"is_error"` } -// ErrorMessage represents an error message type ErrorMessage struct { Error string `json:"error"` Details string `json:"details,omitempty"` } -// CompleteMessage represents completion message type CompleteMessage struct { ServerID uuid.UUID `json:"server_id"` Success bool `json:"success"` Message string `json:"message"` } -// GetStepDescription returns a human-readable description for each step func GetStepDescription(step ServerCreationStep) string { descriptions := map[ServerCreationStep]string{ StepValidation: "Validating server configuration", diff --git a/local/repository/base.go b/local/repository/base.go index efad80b..cae5ac4 100644 --- a/local/repository/base.go +++ b/local/repository/base.go @@ -7,13 +7,11 @@ import ( "gorm.io/gorm" ) -// BaseRepository provides generic CRUD operations for any model type BaseRepository[T any, F any] struct { db *gorm.DB modelType T } -// NewBaseRepository creates a new base repository for the given model type func NewBaseRepository[T any, F any](db *gorm.DB, model T) *BaseRepository[T, F] { return &BaseRepository[T, F]{ db: db, @@ -21,23 +19,19 @@ func NewBaseRepository[T any, F any](db *gorm.DB, model T) *BaseRepository[T, F] } } -// GetAll retrieves all records based on the filter func (r *BaseRepository[T, F]) GetAll(ctx context.Context, filter *F) (*[]T, error) { result := new([]T) query := r.db.WithContext(ctx).Model(&r.modelType) - // Apply filter conditions if filter implements Filterable if filterable, ok := any(filter).(Filterable); ok { query = filterable.ApplyFilter(query) } - // Apply pagination if filter implements Pageable if pageable, ok := any(filter).(Pageable); ok { offset, limit := pageable.Pagination() query = query.Offset(offset).Limit(limit) } - // Apply sorting if filter implements Sortable if sortable, ok := any(filter).(Sortable); ok { field, desc := sortable.GetSorting() if desc { @@ -54,7 +48,6 @@ func (r *BaseRepository[T, F]) GetAll(ctx context.Context, filter *F) (*[]T, err return result, nil } -// GetByID retrieves a single record by ID func (r *BaseRepository[T, F]) GetByID(ctx context.Context, id interface{}) (*T, error) { result := new(T) if err := r.db.WithContext(ctx).Where("id = ?", id).First(result).Error; err != nil { @@ -66,7 +59,6 @@ func (r *BaseRepository[T, F]) GetByID(ctx context.Context, id interface{}) (*T, return result, nil } -// Insert creates a new record func (r *BaseRepository[T, F]) Insert(ctx context.Context, model *T) error { if err := r.db.WithContext(ctx).Create(model).Error; err != nil { return fmt.Errorf("error creating record: %w", err) @@ -74,7 +66,6 @@ func (r *BaseRepository[T, F]) Insert(ctx context.Context, model *T) error { return nil } -// Update modifies an existing record func (r *BaseRepository[T, F]) Update(ctx context.Context, model *T) error { if err := r.db.WithContext(ctx).Save(model).Error; err != nil { return fmt.Errorf("error updating record: %w", err) @@ -82,7 +73,6 @@ func (r *BaseRepository[T, F]) Update(ctx context.Context, model *T) error { return nil } -// Delete removes a record by ID func (r *BaseRepository[T, F]) Delete(ctx context.Context, id interface{}) error { if err := r.db.WithContext(ctx).Delete(new(T), id).Error; err != nil { return fmt.Errorf("error deleting record: %w", err) @@ -90,7 +80,6 @@ func (r *BaseRepository[T, F]) Delete(ctx context.Context, id interface{}) error return nil } -// Count returns the total number of records matching the filter func (r *BaseRepository[T, F]) Count(ctx context.Context, filter *F) (int64, error) { var count int64 query := r.db.WithContext(ctx).Model(&r.modelType) @@ -106,8 +95,6 @@ func (r *BaseRepository[T, F]) Count(ctx context.Context, filter *F) (int64, err return count, nil } -// Interfaces for filter capabilities - type Filterable interface { ApplyFilter(*gorm.DB) *gorm.DB } diff --git a/local/repository/config.go b/local/repository/config.go index 1e4b1e7..e303365 100644 --- a/local/repository/config.go +++ b/local/repository/config.go @@ -17,13 +17,11 @@ func NewConfigRepository(db *gorm.DB) *ConfigRepository { } } -// UpdateConfig updates or creates a Config record func (r *ConfigRepository) UpdateConfig(ctx context.Context, config *model.Config) *model.Config { if err := r.Update(ctx, config); err != nil { - // If update fails, try to insert if err := r.Insert(ctx, config); err != nil { return nil } } return config -} \ No newline at end of file +} diff --git a/local/repository/membership.go b/local/repository/membership.go index ae9d000..f931ff5 100644 --- a/local/repository/membership.go +++ b/local/repository/membership.go @@ -8,20 +8,16 @@ import ( "gorm.io/gorm" ) -// MembershipRepository handles database operations for users, roles, and permissions. type MembershipRepository struct { *BaseRepository[model.User, model.MembershipFilter] } -// NewMembershipRepository creates a new MembershipRepository. func NewMembershipRepository(db *gorm.DB) *MembershipRepository { return &MembershipRepository{ BaseRepository: NewBaseRepository[model.User, model.MembershipFilter](db, model.User{}), } } -// FindUserByUsername finds a user by their username. -// It preloads the user's role and the role's permissions. func (r *MembershipRepository) FindUserByUsername(ctx context.Context, username string) (*model.User, error) { var user model.User db := r.db.WithContext(ctx) @@ -32,7 +28,6 @@ func (r *MembershipRepository) FindUserByUsername(ctx context.Context, username return &user, nil } -// FindUserByIDWithPermissions finds a user by their ID and preloads Role and Permissions. func (r *MembershipRepository) FindUserByIDWithPermissions(ctx context.Context, userID string) (*model.User, error) { var user model.User db := r.db.WithContext(ctx) @@ -43,13 +38,11 @@ func (r *MembershipRepository) FindUserByIDWithPermissions(ctx context.Context, return &user, nil } -// CreateUser creates a new user. func (r *MembershipRepository) CreateUser(ctx context.Context, user *model.User) error { db := r.db.WithContext(ctx) return db.Create(user).Error } -// FindRoleByName finds a role by its name. func (r *MembershipRepository) FindRoleByName(ctx context.Context, name string) (*model.Role, error) { var role model.Role db := r.db.WithContext(ctx) @@ -60,13 +53,11 @@ func (r *MembershipRepository) FindRoleByName(ctx context.Context, name string) return &role, nil } -// CreateRole creates a new role. func (r *MembershipRepository) CreateRole(ctx context.Context, role *model.Role) error { db := r.db.WithContext(ctx) return db.Create(role).Error } -// FindPermissionByName finds a permission by its name. func (r *MembershipRepository) FindPermissionByName(ctx context.Context, name string) (*model.Permission, error) { var permission model.Permission db := r.db.WithContext(ctx) @@ -77,19 +68,16 @@ func (r *MembershipRepository) FindPermissionByName(ctx context.Context, name st return &permission, nil } -// CreatePermission creates a new permission. func (r *MembershipRepository) CreatePermission(ctx context.Context, permission *model.Permission) error { db := r.db.WithContext(ctx) return db.Create(permission).Error } -// AssignPermissionsToRole assigns a set of permissions to a role. func (r *MembershipRepository) AssignPermissionsToRole(ctx context.Context, role *model.Role, permissions []model.Permission) error { db := r.db.WithContext(ctx) return db.Model(role).Association("Permissions").Replace(permissions) } -// GetUserPermissions retrieves all permissions for a given user ID. func (r *MembershipRepository) GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]string, error) { var user model.User db := r.db.WithContext(ctx) @@ -106,7 +94,6 @@ func (r *MembershipRepository) GetUserPermissions(ctx context.Context, userID uu return permissions, nil } -// ListUsers retrieves all users. func (r *MembershipRepository) ListUsers(ctx context.Context) ([]*model.User, error) { var users []*model.User db := r.db.WithContext(ctx) @@ -114,13 +101,11 @@ func (r *MembershipRepository) ListUsers(ctx context.Context) ([]*model.User, er return users, err } -// DeleteUser deletes a user. func (r *MembershipRepository) DeleteUser(ctx context.Context, userID uuid.UUID) error { db := r.db.WithContext(ctx) return db.Delete(&model.User{}, "id = ?", userID).Error } -// FindUserByID finds a user by their ID. func (r *MembershipRepository) FindUserByID(ctx context.Context, userID uuid.UUID) (*model.User, error) { var user model.User db := r.db.WithContext(ctx) @@ -131,13 +116,11 @@ func (r *MembershipRepository) FindUserByID(ctx context.Context, userID uuid.UUI return &user, nil } -// UpdateUser updates a user's details in the database. func (r *MembershipRepository) UpdateUser(ctx context.Context, user *model.User) error { db := r.db.WithContext(ctx) return db.Save(user).Error } -// FindRoleByID finds a role by its ID. func (r *MembershipRepository) FindRoleByID(ctx context.Context, roleID uuid.UUID) (*model.Role, error) { var role model.Role db := r.db.WithContext(ctx) @@ -148,12 +131,10 @@ func (r *MembershipRepository) FindRoleByID(ctx context.Context, roleID uuid.UUI return &role, nil } -// ListUsersWithFilter retrieves users based on the membership filter. func (r *MembershipRepository) ListUsersWithFilter(ctx context.Context, filter *model.MembershipFilter) (*[]model.User, error) { return r.BaseRepository.GetAll(ctx, filter) } -// ListRoles retrieves all roles. func (r *MembershipRepository) ListRoles(ctx context.Context) ([]*model.Role, error) { var roles []*model.Role db := r.db.WithContext(ctx) diff --git a/local/repository/repository.go b/local/repository/repository.go index 664a968..8600daf 100644 --- a/local/repository/repository.go +++ b/local/repository/repository.go @@ -10,11 +10,7 @@ import ( "go.uber.org/dig" ) -// InitializeRepositories -// Initializes Dependency Injection modules for repositories -// -// Args: -// *dig.Container: Dig Container +// *dig.Container: Dig Container func InitializeRepositories(c *dig.Container) { c.Provide(NewServiceControlRepository) c.Provide(NewStateHistoryRepository) @@ -24,16 +20,14 @@ func InitializeRepositories(c *dig.Container) { c.Provide(NewSteamCredentialsRepository) c.Provide(NewMembershipRepository) - // Provide the Steam2FAManager as a singleton if err := c.Provide(func() *model.Steam2FAManager { manager := model.NewSteam2FAManager() - - // Use graceful shutdown manager for cleanup goroutine + shutdownManager := graceful.GetManager() shutdownManager.RunGoroutine(func(ctx context.Context) { ticker := time.NewTicker(15 * time.Minute) defer ticker.Stop() - + for { select { case <-ctx.Done(): @@ -43,7 +37,7 @@ func InitializeRepositories(c *dig.Container) { } } }) - + return manager }); err != nil { logging.Panic("unable to initialize steam 2fa manager") diff --git a/local/repository/server.go b/local/repository/server.go index c82566f..ffd318c 100644 --- a/local/repository/server.go +++ b/local/repository/server.go @@ -20,10 +20,6 @@ func NewServerRepository(db *gorm.DB) *ServerRepository { return repo } -// GetFirstByServiceName -// Gets first row from Server table. -// -// Args: // context.Context: Application context // Returns: // model.ServerModel: Server object from database. diff --git a/local/repository/state_history.go b/local/repository/state_history.go index de81458..e537182 100644 --- a/local/repository/state_history.go +++ b/local/repository/state_history.go @@ -18,17 +18,14 @@ func NewStateHistoryRepository(db *gorm.DB) *StateHistoryRepository { } } -// GetAll retrieves all state history records with the given filter func (r *StateHistoryRepository) GetAll(ctx context.Context, filter *model.StateHistoryFilter) (*[]model.StateHistory, error) { return r.BaseRepository.GetAll(ctx, filter) } -// Insert creates a new state history record func (r *StateHistoryRepository) Insert(ctx context.Context, model *model.StateHistory) error { return r.BaseRepository.Insert(ctx, model) } -// GetLastSessionID gets the last session ID for a server func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID uuid.UUID) (uuid.UUID, error) { var lastSession model.StateHistory result := r.BaseRepository.db.WithContext(ctx). @@ -38,7 +35,7 @@ func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID if result.Error != nil { if result.Error == gorm.ErrRecordNotFound { - return uuid.Nil, nil // Return nil UUID if no sessions found + return uuid.Nil, nil } return uuid.Nil, result.Error } @@ -46,10 +43,8 @@ func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID return lastSession.SessionID, nil } -// GetSummaryStats calculates peak players, total sessions, and average players. func (r *StateHistoryRepository) GetSummaryStats(ctx context.Context, filter *model.StateHistoryFilter) (model.StateHistoryStats, error) { var stats model.StateHistoryStats - // Parse ServerID to UUID for query serverUUID, err := uuid.Parse(filter.ServerID) if err != nil { return model.StateHistoryStats{}, err @@ -73,12 +68,10 @@ func (r *StateHistoryRepository) GetSummaryStats(ctx context.Context, filter *mo return stats, nil } -// GetTotalPlaytime calculates the total playtime in minutes. func (r *StateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *model.StateHistoryFilter) (int, error) { var totalPlaytime struct { TotalMinutes float64 } - // Parse ServerID to UUID for query serverUUID, err := uuid.Parse(filter.ServerID) if err != nil { return 0, err @@ -100,10 +93,8 @@ func (r *StateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *m return int(totalPlaytime.TotalMinutes), nil } -// GetPlayerCountOverTime gets downsampled player count data. func (r *StateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, filter *model.StateHistoryFilter) ([]model.PlayerCountPoint, error) { var points []model.PlayerCountPoint - // Parse ServerID to UUID for query serverUUID, err := uuid.Parse(filter.ServerID) if err != nil { return points, err @@ -122,10 +113,8 @@ func (r *StateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, fil return points, err } -// GetSessionTypes counts sessions by type. func (r *StateHistoryRepository) GetSessionTypes(ctx context.Context, filter *model.StateHistoryFilter) ([]model.SessionCount, error) { var sessionTypes []model.SessionCount - // Parse ServerID to UUID for query serverUUID, err := uuid.Parse(filter.ServerID) if err != nil { return sessionTypes, err @@ -145,10 +134,8 @@ func (r *StateHistoryRepository) GetSessionTypes(ctx context.Context, filter *mo return sessionTypes, err } -// GetDailyActivity counts sessions per day. func (r *StateHistoryRepository) GetDailyActivity(ctx context.Context, filter *model.StateHistoryFilter) ([]model.DailyActivity, error) { var dailyActivity []model.DailyActivity - // Parse ServerID to UUID for query serverUUID, err := uuid.Parse(filter.ServerID) if err != nil { return dailyActivity, err @@ -167,10 +154,8 @@ func (r *StateHistoryRepository) GetDailyActivity(ctx context.Context, filter *m return dailyActivity, err } -// GetRecentSessions retrieves the 10 most recent sessions. func (r *StateHistoryRepository) GetRecentSessions(ctx context.Context, filter *model.StateHistoryFilter) ([]model.RecentSession, error) { var recentSessions []model.RecentSession - // Parse ServerID to UUID for query serverUUID, err := uuid.Parse(filter.ServerID) if err != nil { return recentSessions, err diff --git a/local/service/config.go b/local/service/config.go index 0161c15..97144c5 100644 --- a/local/service/config.go +++ b/local/service/config.go @@ -77,8 +77,8 @@ func NewConfigService(repository *repository.ConfigRepository, serverRepository repository: repository, serverRepository: serverRepository, configCache: model.NewServerConfigCache(model.CacheConfig{ - ExpirationTime: 5 * time.Minute, // Cache configs for 5 minutes - ThrottleTime: 1 * time.Second, // Prevent rapid re-reads + ExpirationTime: 5 * time.Minute, + ThrottleTime: 1 * time.Second, DefaultStatus: model.StatusUnknown, }), } @@ -88,10 +88,6 @@ func (as *ConfigService) SetServerService(serverService *ServerService) { as.serverService = serverService } -// UpdateConfig -// Updates physical config file and caches it in database. -// -// Args: // context.Context: Application context // Returns: // string: Application version @@ -103,7 +99,6 @@ func (as *ConfigService) UpdateConfig(ctx *fiber.Ctx, body *map[string]interface return as.updateConfigInternal(ctx.UserContext(), serverID, configFile, body, override) } -// updateConfigInternal handles the actual config update logic without Fiber dependencies func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID string, configFile string, body *map[string]interface{}, override bool) (*model.Config, error) { serverUUID, err := uuid.Parse(serverID) if err != nil { @@ -117,17 +112,14 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri return nil, fmt.Errorf("server not found") } - // Read existing config configPath := filepath.Join(server.GetConfigPath(), configFile) oldData, err := os.ReadFile(configPath) if err != nil { if os.IsNotExist(err) { - // Create directory if it doesn't exist dir := filepath.Dir(configPath) if err := os.MkdirAll(dir, 0755); err != nil { return nil, err } - // Create empty JSON file if err := os.WriteFile(configPath, []byte("{}"), 0644); err != nil { return nil, err } @@ -142,7 +134,6 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri return nil, err } - // Write new config newData, err := json.Marshal(&body) if err != nil { return nil, err @@ -168,12 +159,9 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri return nil, err } - // Invalidate all configs for this server since configs can be interdependent as.configCache.InvalidateServerCache(serverID) as.serverService.StartAccServerRuntime(server) - - // Log change return as.repository.UpdateConfig(ctx, &model.Config{ ServerID: serverUUID, ConfigFile: configFile, @@ -183,10 +171,6 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri }), nil } -// GetConfig -// Gets physical config file and caches it in database. -// -// Args: // context.Context: Application context // Returns: // string: Application version @@ -202,7 +186,6 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) { return nil, fiber.NewError(404, "Server not found") } - // Try to get from cache based on config file type switch configFile { case ConfigurationJson: if cached, ok := as.configCache.GetConfiguration(serverIDStr); ok { @@ -233,7 +216,6 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) { logging.Debug("Cache miss for server ID: %s, file: %s - loading from disk", serverIDStr, configFile) - // Not in cache, load from disk configPath := filepath.Join(server.GetConfigPath(), configFile) decoder := DecodeFileName(configFile) if decoder == nil { @@ -244,7 +226,6 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) { if err != nil { if os.IsNotExist(err) { logging.Debug("Config file not found, creating default for server ID: %s, file: %s", serverIDStr, configFile) - // Return empty config if file doesn't exist switch configFile { case ConfigurationJson: return &model.Configuration{}, nil @@ -261,7 +242,6 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) { return nil, err } - // Cache the loaded config switch configFile { case ConfigurationJson: as.configCache.UpdateConfiguration(serverIDStr, *config.(*model.Configuration)) @@ -279,8 +259,6 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) { return config, nil } -// GetConfigs -// Gets all configurations for a server, using cache when possible. func (as *ConfigService) GetConfigs(ctx *fiber.Ctx) (*model.Configurations, error) { serverID := ctx.Params("id") @@ -298,7 +276,6 @@ func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configuration logging.Info("Loading configs for server ID: %s at path: %s", serverIDStr, server.GetConfigPath()) configs := &model.Configurations{} - // Load configuration if cached, ok := as.configCache.GetConfiguration(serverIDStr); ok { logging.Debug("Using cached configuration for server %s", serverIDStr) configs.Configuration = *cached @@ -313,7 +290,6 @@ func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configuration as.configCache.UpdateConfiguration(serverIDStr, config) } - // Load assist rules if cached, ok := as.configCache.GetAssistRules(serverIDStr); ok { logging.Debug("Using cached assist rules for server %s", serverIDStr) configs.AssistRules = *cached @@ -328,7 +304,6 @@ func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configuration as.configCache.UpdateAssistRules(serverIDStr, rules) } - // Load event config if cached, ok := as.configCache.GetEvent(serverIDStr); ok { logging.Debug("Using cached event config for server %s", serverIDStr) configs.Event = *cached @@ -344,7 +319,6 @@ func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configuration as.configCache.UpdateEvent(serverIDStr, event) } - // Load event rules if cached, ok := as.configCache.GetEventRules(serverIDStr); ok { logging.Debug("Using cached event rules for server %s", serverIDStr) configs.EventRules = *cached @@ -359,7 +333,6 @@ func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configuration as.configCache.UpdateEventRules(serverIDStr, rules) } - // Load settings if cached, ok := as.configCache.GetSettings(serverIDStr); ok { logging.Debug("Using cached settings for server %s", serverIDStr) configs.Settings = *cached @@ -475,9 +448,7 @@ func (as *ConfigService) GetConfiguration(server *model.Server) (*model.Configur return &config, nil } -// SaveConfiguration saves the configuration for a server func (as *ConfigService) SaveConfiguration(server *model.Server, config *model.Configuration) error { - // Convert config to map for UpdateConfig configMap := make(map[string]interface{}) configBytes, err := json.Marshal(config) if err != nil { @@ -487,7 +458,6 @@ func (as *ConfigService) SaveConfiguration(server *model.Server, config *model.C return fmt.Errorf("failed to unmarshal configuration: %v", err) } - // Update the configuration using the internal method _, err = as.updateConfigInternal(context.Background(), server.ID.String(), ConfigurationJson, &configMap, true) return err } diff --git a/local/service/firewall_service.go b/local/service/firewall_service.go index e857462..0894a51 100644 --- a/local/service/firewall_service.go +++ b/local/service/firewall_service.go @@ -96,11 +96,9 @@ func (s *FirewallService) DeleteServerRules(serverName string, tcpPorts, udpPort } func (s *FirewallService) UpdateServerRules(serverName string, tcpPorts, udpPorts []int) error { - // First delete existing rules if err := s.DeleteServerRules(serverName, tcpPorts, udpPorts); err != nil { return err } - // Then create new rules return s.CreateServerRules(serverName, tcpPorts, udpPorts) -} \ No newline at end of file +} diff --git a/local/service/membership.go b/local/service/membership.go index 157d8fc..1e0f008 100644 --- a/local/service/membership.go +++ b/local/service/membership.go @@ -12,13 +12,11 @@ import ( "github.com/google/uuid" ) -// CacheInvalidator interface for cache invalidation type CacheInvalidator interface { InvalidateUserPermissions(userID string) InvalidateAllUserPermissions() } -// MembershipService provides business logic for membership-related operations. type MembershipService struct { repo *repository.MembershipRepository cacheInvalidator CacheInvalidator @@ -26,29 +24,25 @@ type MembershipService struct { openJwtHandler *jwt.OpenJWTHandler } -// NewMembershipService creates a new MembershipService. func NewMembershipService(repo *repository.MembershipRepository, jwtHandler *jwt.JWTHandler, openJwtHandler *jwt.OpenJWTHandler) *MembershipService { return &MembershipService{ repo: repo, - cacheInvalidator: nil, // Will be set later via SetCacheInvalidator + cacheInvalidator: nil, jwtHandler: jwtHandler, openJwtHandler: openJwtHandler, } } -// SetCacheInvalidator sets the cache invalidator after service initialization func (s *MembershipService) SetCacheInvalidator(invalidator CacheInvalidator) { s.cacheInvalidator = invalidator } -// Login authenticates a user and returns a JWT. func (s *MembershipService) HandleLogin(ctx context.Context, username, password string) (*model.User, error) { user, err := s.repo.FindUserByUsername(ctx, username) if err != nil { return nil, errors.New("invalid credentials") } - // Use secure password verification with constant-time comparison if err := user.VerifyPassword(password); err != nil { return nil, errors.New("invalid credentials") } @@ -56,7 +50,6 @@ func (s *MembershipService) HandleLogin(ctx context.Context, username, password return user, nil } -// Login authenticates a user and returns a JWT. func (s *MembershipService) Login(ctx context.Context, username, password string) (string, error) { user, err := s.HandleLogin(ctx, username, password) if err != nil { @@ -70,7 +63,6 @@ func (s *MembershipService) GenerateOpenToken(ctx context.Context, userId string return s.openJwtHandler.GenerateToken(userId) } -// CreateUser creates a new user. func (s *MembershipService) CreateUser(ctx context.Context, username, password, roleName string) (*model.User, error) { role, err := s.repo.FindRoleByName(ctx, roleName) @@ -94,43 +86,35 @@ func (s *MembershipService) CreateUser(ctx context.Context, username, password, return user, nil } -// ListUsers retrieves all users. func (s *MembershipService) ListUsers(ctx context.Context) ([]*model.User, error) { return s.repo.ListUsers(ctx) } -// GetUser retrieves a single user by ID. func (s *MembershipService) GetUser(ctx context.Context, userID uuid.UUID) (*model.User, error) { return s.repo.FindUserByID(ctx, userID) } -// GetUserWithPermissions retrieves a single user by ID with their role and permissions. func (s *MembershipService) GetUserWithPermissions(ctx context.Context, userID string) (*model.User, error) { return s.repo.FindUserByIDWithPermissions(ctx, userID) } -// UpdateUserRequest defines the request body for updating a user. type UpdateUserRequest struct { Username *string `json:"username"` Password *string `json:"password"` RoleID *uuid.UUID `json:"roleId"` } -// DeleteUser deletes a user with validation to prevent Super Admin deletion. func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) error { - // Get user with role information user, err := s.repo.FindUserByID(ctx, userID) if err != nil { return errors.New("user not found") } - // Get role to check if it's Super Admin role, err := s.repo.FindRoleByID(ctx, user.RoleID) if err != nil { return errors.New("user role not found") } - // Prevent deletion of Super Admin users if role.Name == "Super Admin" { return errors.New("cannot delete Super Admin user") } @@ -140,7 +124,6 @@ func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) er return err } - // Invalidate cache for deleted user if s.cacheInvalidator != nil { s.cacheInvalidator.InvalidateUserPermissions(userID.String()) } @@ -149,7 +132,6 @@ func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) er return nil } -// UpdateUser updates a user's details. func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, req UpdateUserRequest) (*model.User, error) { user, err := s.repo.FindUserByID(ctx, userID) if err != nil { @@ -161,12 +143,10 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re } if req.Password != nil && *req.Password != "" { - // Password will be automatically hashed in BeforeUpdate hook user.Password = *req.Password } if req.RoleID != nil { - // Check if role exists _, err := s.repo.FindRoleByID(ctx, *req.RoleID) if err != nil { return nil, errors.New("role not found") @@ -178,7 +158,6 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re return nil, err } - // Invalidate cache if role was changed if req.RoleID != nil && s.cacheInvalidator != nil { s.cacheInvalidator.InvalidateUserPermissions(userID.String()) } @@ -187,14 +166,12 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re return user, nil } -// HasPermission checks if a user has a specific permission. func (s *MembershipService) HasPermission(ctx context.Context, userID string, permissionName string) (bool, error) { user, err := s.repo.FindUserByIDWithPermissions(ctx, userID) if err != nil { return false, err } - // Super admin and Admin have all permissions if user.Role.Name == "Super Admin" || user.Role.Name == "Admin" { return true, nil } @@ -208,15 +185,13 @@ func (s *MembershipService) HasPermission(ctx context.Context, userID string, pe return false, nil } -// SetupInitialData creates the initial roles and permissions. func (s *MembershipService) SetupInitialData(ctx context.Context) error { - // Define all permissions permissions := model.AllPermissions() createdPermissions := make([]model.Permission, 0) for _, pName := range permissions { perm, err := s.repo.FindPermissionByName(ctx, pName) - if err != nil { // Assuming error means not found + if err != nil { perm = &model.Permission{Name: pName} if err := s.repo.CreatePermission(ctx, perm); err != nil { return err @@ -225,7 +200,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error { createdPermissions = append(createdPermissions, *perm) } - // Create Super Admin role with all permissions superAdminRole, err := s.repo.FindRoleByName(ctx, "Super Admin") if err != nil { superAdminRole = &model.Role{Name: "Super Admin"} @@ -237,7 +211,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error { return err } - // Create Admin role with same permissions as Super Admin adminRole, err := s.repo.FindRoleByName(ctx, "Admin") if err != nil { adminRole = &model.Role{Name: "Admin"} @@ -249,7 +222,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error { return err } - // Create Manager role with limited permissions (excluding membership, role, user, server create/delete) managerRole, err := s.repo.FindRoleByName(ctx, "Manager") if err != nil { managerRole = &model.Role{Name: "Manager"} @@ -258,7 +230,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error { } } - // Define manager permissions (limited set) managerPermissionNames := []string{ model.ServerView, model.ServerUpdate, @@ -282,16 +253,14 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error { return err } - // Invalidate all caches after role setup changes if s.cacheInvalidator != nil { s.cacheInvalidator.InvalidateAllUserPermissions() } - // Create a default admin user if one doesn't exist _, err = s.repo.FindUserByUsername(ctx, "admin") if err != nil { logging.Debug("Creating default admin user") - _, err = s.CreateUser(ctx, "admin", os.Getenv("PASSWORD"), "Super Admin") // Default password, should be changed + _, err = s.CreateUser(ctx, "admin", os.Getenv("PASSWORD"), "Super Admin") if err != nil { return err } @@ -300,7 +269,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error { return nil } -// GetAllRoles retrieves all roles for dropdown selection. func (s *MembershipService) GetAllRoles(ctx context.Context) ([]*model.Role, error) { return s.repo.ListRoles(ctx) } diff --git a/local/service/server.go b/local/service/server.go index b193120..b7ac070 100644 --- a/local/service/server.go +++ b/local/service/server.go @@ -20,7 +20,7 @@ import ( const ( DefaultStartPort = 9600 - RequiredPortCount = 1 // Update this if ACC needs more ports + RequiredPortCount = 1 ) type ServerService struct { @@ -45,18 +45,15 @@ type pendingState struct { } func (s *ServerService) ensureLogTailing(server *model.Server, instance *tracking.AccServerInstance) { - // Check if we already have a tailer if _, exists := s.logTailers.Load(server.ID); exists { return } - // Start tailing in a goroutine that handles file creation/deletion go func() { logPath := filepath.Join(server.GetLogPath(), "server.log") tailer := tracking.NewLogTailer(logPath, instance.HandleLogLine) s.logTailers.Store(server.ID, tailer) - // Start tailing and automatically handle file changes tailer.Start() }() } @@ -82,7 +79,6 @@ func NewServerService( webSocketService: webSocketService, } - // Initialize server instances servers, err := repository.GetAll(context.Background(), &model.ServerFilter{}) if err != nil { logging.Error("Failed to get servers: %v", err) @@ -90,7 +86,6 @@ func NewServerService( } for i := range *servers { - // Initialize instance regardless of status logging.Info("Starting server runtime for server ID: %d", (*servers)[i].ID) service.StartAccServerRuntime(&(*servers)[i]) } @@ -99,7 +94,7 @@ func NewServerService( } func (s *ServerService) shouldInsertStateHistory(serverID uuid.UUID) bool { - insertInterval := 5 * time.Minute // Configure this as needed + insertInterval := 5 * time.Minute lastInsertInterface, exists := s.lastInsertTimes.Load(serverID) if !exists { @@ -122,16 +117,15 @@ func (s *ServerService) getNextSessionID(serverID uuid.UUID) uuid.UUID { lastID, err := s.stateHistoryRepo.GetLastSessionID(context.Background(), serverID) if err != nil { logging.Error("Failed to get last session ID for server %s: %v", serverID, err) - return uuid.New() // Return new UUID as fallback + return uuid.New() } if lastID == uuid.Nil { - return uuid.New() // Return new UUID if no previous session + return uuid.New() } - return uuid.New() // Always generate new UUID for each session + return uuid.New() } func (s *ServerService) insertStateHistory(serverID uuid.UUID, state *model.ServerState) { - // Get or create session ID when session changes currentSessionInterface, exists := s.instances.Load(serverID) var sessionID uuid.UUID if !exists { @@ -163,7 +157,6 @@ func (s *ServerService) insertStateHistory(serverID uuid.UUID, state *model.Serv } func (s *ServerService) updateSessionDuration(server *model.Server, sessionType model.TrackSession) { - // Get configs using helper methods event, err := s.configService.GetEventConfig(server) if err != nil { event = &model.EventConfig{} @@ -181,9 +174,7 @@ func (s *ServerService) updateSessionDuration(server *model.Server, sessionType serverInstance.State.Track = event.Track serverInstance.State.MaxConnections = configuration.MaxConnections.ToInt() - // Check if session type has changed if serverInstance.State.Session != sessionType { - // Get new session ID for the new session sessionID := s.getNextSessionID(server.ID) s.sessionIDs.Store(server.ID, sessionID) } @@ -204,33 +195,27 @@ func (s *ServerService) updateSessionDuration(server *model.Server, sessionType } func (s *ServerService) GenerateServerPath(server *model.Server) { - // Get the base steamcmd path from environment variable steamCMDPath := env.GetSteamCMDDirPath() - server.FromSteamCMD = true server.Path = server.GenerateServerPath(steamCMDPath) + server.FromSteamCMD = true } func (s *ServerService) handleStateChange(server *model.Server, state *model.ServerState) { - // Update session duration when session changes s.updateSessionDuration(server, state.Session) - // Invalidate status cache when server state changes s.apiService.statusCache.InvalidateStatus(server.ServiceName) - // Cancel existing timer if any if debouncer, exists := s.debouncers.Load(server.ID); exists { pending := debouncer.(*pendingState) pending.timer.Stop() } - // Create new timer timer := time.NewTimer(5 * time.Minute) s.debouncers.Store(server.ID, &pendingState{ timer: timer, state: state, }) - // Start goroutine to handle the delayed insert go func() { <-timer.C if debouncer, exists := s.debouncers.Load(server.ID); exists { @@ -240,14 +225,12 @@ func (s *ServerService) handleStateChange(server *model.Server, state *model.Ser } }() - // If enough time has passed since last insert, insert immediately if s.shouldInsertStateHistory(server.ID) { s.insertStateHistory(server.ID, state) } } func (s *ServerService) StartAccServerRuntime(server *model.Server) { - // Get or create instance instanceInterface, exists := s.instances.Load(server.ID) var instance *tracking.AccServerInstance if !exists { @@ -259,20 +242,14 @@ func (s *ServerService) StartAccServerRuntime(server *model.Server) { instance = instanceInterface.(*tracking.AccServerInstance) } - // Invalidate config cache for this server before loading new configs serverIDStr := server.ID.String() s.configService.configCache.InvalidateServerCache(serverIDStr) s.updateSessionDuration(server, instance.State.Session) - // Ensure log tailing is running (regardless of server status) s.ensureLogTailing(server, instance) } -// GetAll -// Gets All rows from Server table. -// -// Args: // context.Context: Application context // Returns: // string: Application version @@ -304,10 +281,6 @@ func (s *ServerService) GetAll(ctx *fiber.Ctx, filter *model.ServerFilter) (*[]m return servers, nil } -// GetById -// Gets rows by ID from Server table. -// -// Args: // context.Context: Application context // Returns: // string: Application version @@ -334,22 +307,16 @@ func (as *ServerService) GetById(ctx *fiber.Ctx, serverID uuid.UUID) (*model.Ser return server, nil } -// CreateServerAsync starts server creation asynchronously and returns immediately func (s *ServerService) CreateServerAsync(ctx *fiber.Ctx, server *model.Server) error { - // Perform basic validation first if err := server.Validate(); err != nil { return err } - // Generate server path s.GenerateServerPath(server) - // Create a background context that won't be cancelled when the HTTP request ends bgCtx := context.Background() - // Start the actual creation process in a goroutine go func() { - // Create server in background without using fiber.Ctx if err := s.createServerBackground(bgCtx, server); err != nil { logging.Error("Async server creation failed for server %s: %v", server.ID, err) s.webSocketService.BroadcastError(server.ID, "Server creation failed", err.Error()) @@ -361,11 +328,9 @@ func (s *ServerService) CreateServerAsync(ctx *fiber.Ctx, server *model.Server) } func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error { - // Broadcast step: validation s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusInProgress, model.GetStepDescription(model.StepValidation), "") - // Validate basic server configuration if err := server.Validate(); err != nil { s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusFailed, "", fmt.Sprintf("Validation failed: %v", err)) @@ -375,19 +340,15 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusCompleted, "Server configuration validated successfully", "") - // Broadcast step: directory creation s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusInProgress, model.GetStepDescription(model.StepDirectoryCreation), "") - // Directory creation is handled within InstallServer, so we mark it as completed s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusCompleted, "Server directories prepared", "") - // Broadcast step: Steam download s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusInProgress, model.GetStepDescription(model.StepSteamDownload), "") - // Install server using SteamCMD with streaming support if err := s.steamService.InstallServerWithWebSocket(ctx.UserContext(), server.Path, &server.ID, s.webSocketService); err != nil { s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusFailed, "", fmt.Sprintf("Steam installation failed: %v", err)) @@ -397,11 +358,9 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusCompleted, "Server files downloaded successfully", "") - // Broadcast step: config generation s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusInProgress, model.GetStepDescription(model.StepConfigGeneration), "") - // Find available ports for server ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount) if err != nil { s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed, @@ -409,10 +368,8 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error return fmt.Errorf("failed to find available ports: %v", err) } - // Use the first port for both TCP and UDP serverPort := ports[0] - // Update server configuration with the allocated port if err := s.updateServerPort(server, serverPort); err != nil { s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed, "", fmt.Sprintf("Failed to update server configuration: %v", err)) @@ -422,17 +379,14 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusCompleted, fmt.Sprintf("Server configuration generated (Port: %d)", serverPort), "") - // Broadcast step: service creation s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusInProgress, model.GetStepDescription(model.StepServiceCreation), "") - // Create Windows service with correct paths execPath := filepath.Join(server.GetServerPath(), "accServer.exe") serverWorkingDir := filepath.Join(server.GetServerPath(), "server") if err := s.windowsService.CreateService(ctx.UserContext(), server.ServiceName, execPath, serverWorkingDir, nil); err != nil { s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusFailed, "", fmt.Sprintf("Failed to create Windows service: %v", err)) - // Cleanup on failure s.steamService.UninstallServer(server.Path) return fmt.Errorf("failed to create Windows service: %v", err) } @@ -440,7 +394,6 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusCompleted, fmt.Sprintf("Windows service '%s' created successfully", server.ServiceName), "") - // Broadcast step: firewall rules s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusInProgress, model.GetStepDescription(model.StepFirewallRules), "") @@ -450,7 +403,6 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error if err := s.firewallService.CreateServerRules(server.ServiceName, tcpPorts, udpPorts); err != nil { s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusFailed, "", fmt.Sprintf("Failed to create firewall rules: %v", err)) - // Cleanup on failure s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName) s.steamService.UninstallServer(server.Path) return fmt.Errorf("failed to create firewall rules: %v", err) @@ -459,15 +411,12 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusCompleted, fmt.Sprintf("Firewall rules created for port %d", serverPort), "") - // Broadcast step: database save s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusInProgress, model.GetStepDescription(model.StepDatabaseSave), "") - // Insert server into database if err := s.repository.Insert(ctx.UserContext(), server); err != nil { s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusFailed, "", fmt.Sprintf("Failed to save server to database: %v", err)) - // Cleanup on failure s.firewallService.DeleteServerRules(server.ServiceName, tcpPorts, udpPorts) s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName) s.steamService.UninstallServer(server.Path) @@ -477,10 +426,8 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusCompleted, "Server saved to database successfully", "") - // Initialize server runtime s.StartAccServerRuntime(server) - // Broadcast completion s.webSocketService.BroadcastStep(server.ID, model.StepCompleted, model.StatusCompleted, model.GetStepDescription(model.StepCompleted), "") @@ -490,13 +437,10 @@ func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error return nil } -// createServerBackground performs server creation in background without fiber.Ctx func (s *ServerService) createServerBackground(ctx context.Context, server *model.Server) error { - // Broadcast step: validation s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusInProgress, model.GetStepDescription(model.StepValidation), "") - // Validate basic server configuration (already done in async method, but double-check) if err := server.Validate(); err != nil { s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusFailed, "", fmt.Sprintf("Validation failed: %v", err)) @@ -506,20 +450,16 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode s.webSocketService.BroadcastStep(server.ID, model.StepValidation, model.StatusCompleted, "Server configuration validated successfully", "") - // Broadcast step: directory creation s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusInProgress, model.GetStepDescription(model.StepDirectoryCreation), "") - // Directory creation is handled within InstallServer, so we mark it as completed s.webSocketService.BroadcastStep(server.ID, model.StepDirectoryCreation, model.StatusCompleted, "Server directories prepared", "") - // Broadcast step: Steam download s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusInProgress, model.GetStepDescription(model.StepSteamDownload), "") - // Install server using SteamCMD with streaming support - if err := s.steamService.InstallServerWithWebSocket(ctx, server.GetServerPath(), &server.ID, s.webSocketService); err != nil { + if err := s.steamService.InstallServerWithWebSocket(ctx, server.Path, &server.ID, s.webSocketService); err != nil { s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusFailed, "", fmt.Sprintf("Steam installation failed: %v", err)) return fmt.Errorf("failed to install server: %v", err) @@ -528,11 +468,9 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode s.webSocketService.BroadcastStep(server.ID, model.StepSteamDownload, model.StatusCompleted, "Server files downloaded successfully", "") - // Broadcast step: config generation s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusInProgress, model.GetStepDescription(model.StepConfigGeneration), "") - // Find available ports for server ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount) if err != nil { s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed, @@ -540,10 +478,8 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode return fmt.Errorf("failed to find available ports: %v", err) } - // Use the first port for both TCP and UDP serverPort := ports[0] - // Update server configuration with the allocated port if err := s.updateServerPort(server, serverPort); err != nil { s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusFailed, "", fmt.Sprintf("Failed to update server configuration: %v", err)) @@ -553,17 +489,14 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode s.webSocketService.BroadcastStep(server.ID, model.StepConfigGeneration, model.StatusCompleted, fmt.Sprintf("Server configuration generated (Port: %d)", serverPort), "") - // Broadcast step: service creation s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusInProgress, model.GetStepDescription(model.StepServiceCreation), "") - // Create Windows service with correct paths execPath := filepath.Join(server.GetServerPath(), "accServer.exe") serverWorkingDir := filepath.Join(server.GetServerPath(), "server") if err := s.windowsService.CreateService(ctx, server.ServiceName, execPath, serverWorkingDir, nil); err != nil { s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusFailed, "", fmt.Sprintf("Failed to create Windows service: %v", err)) - // Cleanup on failure s.steamService.UninstallServer(server.Path) return fmt.Errorf("failed to create Windows service: %v", err) } @@ -571,7 +504,6 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode s.webSocketService.BroadcastStep(server.ID, model.StepServiceCreation, model.StatusCompleted, fmt.Sprintf("Windows service '%s' created successfully", server.ServiceName), "") - // Broadcast step: firewall rules s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusInProgress, model.GetStepDescription(model.StepFirewallRules), "") @@ -581,7 +513,6 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode if err := s.firewallService.CreateServerRules(server.ServiceName, tcpPorts, udpPorts); err != nil { s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusFailed, "", fmt.Sprintf("Failed to create firewall rules: %v", err)) - // Cleanup on failure s.windowsService.DeleteService(ctx, server.ServiceName) s.steamService.UninstallServer(server.Path) return fmt.Errorf("failed to create firewall rules: %v", err) @@ -590,15 +521,12 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode s.webSocketService.BroadcastStep(server.ID, model.StepFirewallRules, model.StatusCompleted, fmt.Sprintf("Firewall rules created for port %d", serverPort), "") - // Broadcast step: database save s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusInProgress, model.GetStepDescription(model.StepDatabaseSave), "") - // Insert server into database if err := s.repository.Insert(ctx, server); err != nil { s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusFailed, "", fmt.Sprintf("Failed to save server to database: %v", err)) - // Cleanup on failure s.firewallService.DeleteServerRules(server.ServiceName, tcpPorts, udpPorts) s.windowsService.DeleteService(ctx, server.ServiceName) s.steamService.UninstallServer(server.Path) @@ -608,10 +536,8 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode s.webSocketService.BroadcastStep(server.ID, model.StepDatabaseSave, model.StatusCompleted, "Server saved to database successfully", "") - // Initialize server runtime s.StartAccServerRuntime(server) - // Broadcast completion s.webSocketService.BroadcastStep(server.ID, model.StepCompleted, model.StatusCompleted, model.GetStepDescription(model.StepCompleted), "") @@ -622,18 +548,15 @@ func (s *ServerService) createServerBackground(ctx context.Context, server *mode } func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error { - // Get server details server, err := s.repository.GetByID(ctx.UserContext(), serverID) if err != nil { return fmt.Errorf("failed to get server details: %v", err) } - // Stop and remove Windows service if err := s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName); err != nil { logging.Error("Failed to delete Windows service: %v", err) } - // Remove firewall rules configuration, err := s.configService.GetConfiguration(server) if err != nil { logging.Error("Failed to get configuration for server %d: %v", server.ID, err) @@ -644,17 +567,14 @@ func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error { logging.Error("Failed to delete firewall rules: %v", err) } - // Uninstall server files if err := s.steamService.UninstallServer(server.Path); err != nil { logging.Error("Failed to uninstall server: %v", err) } - // Remove from database if err := s.repository.Delete(ctx.UserContext(), serverID); err != nil { return fmt.Errorf("failed to delete server from database: %v", err) } - // Cleanup runtime resources if tailer, exists := s.logTailers.Load(server.ID); exists { tailer.(*tracking.LogTailer).Stop() s.logTailers.Delete(server.ID) @@ -664,84 +584,27 @@ func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error { s.debouncers.Delete(server.ID) s.sessionIDs.Delete(server.ID) - // Invalidate status cache for deleted server s.apiService.statusCache.InvalidateStatus(server.ServiceName) return nil } -func (s *ServerService) UpdateServer(ctx *fiber.Ctx, server *model.Server) error { - // Validate server configuration - if err := server.Validate(); err != nil { - return err - } - - // Get existing server details - existingServer, err := s.repository.GetByID(ctx.UserContext(), server.ID) - if err != nil { - return fmt.Errorf("failed to get existing server details: %v", err) - } - - // Update server files if path changed - if existingServer.Path != server.Path { - if err := s.steamService.InstallServer(ctx.UserContext(), server.Path, &server.ID); err != nil { - return fmt.Errorf("failed to install server to new location: %v", err) - } - // Clean up old installation - if err := s.steamService.UninstallServer(existingServer.Path); err != nil { - logging.Error("Failed to remove old server installation: %v", err) - } - } - - // Update Windows service if necessary - if existingServer.ServiceName != server.ServiceName || existingServer.Path != server.Path { - execPath := filepath.Join(server.GetServerPath(), "accServer.exe") - serverWorkingDir := server.GetServerPath() - if err := s.windowsService.UpdateService(ctx.UserContext(), server.ServiceName, execPath, serverWorkingDir, nil); err != nil { - return fmt.Errorf("failed to update Windows service: %v", err) - } - } - - // Update firewall rules if service name changed - if existingServer.ServiceName != server.ServiceName { - if err := s.configureFirewall(server); err != nil { - return fmt.Errorf("failed to update firewall rules: %v", err) - } - // Invalidate cache for old service name - s.apiService.statusCache.InvalidateStatus(existingServer.ServiceName) - } - - // Update database record - if err := s.repository.Update(ctx.UserContext(), server); err != nil { - return fmt.Errorf("failed to update server in database: %v", err) - } - - // Restart server runtime - s.StartAccServerRuntime(server) - - return nil -} - func (s *ServerService) configureFirewall(server *model.Server) error { - // Find available ports for the server ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount) if err != nil { return fmt.Errorf("failed to find available ports: %v", err) } - // Use the first port for both TCP and UDP serverPort := ports[0] tcpPorts := []int{serverPort} udpPorts := []int{serverPort} logging.Info("Configuring firewall for server %d with port %d", server.ID, serverPort) - // Configure firewall rules if err := s.firewallService.UpdateServerRules(server.Name, tcpPorts, udpPorts); err != nil { return fmt.Errorf("failed to configure firewall: %v", err) } - // Update server configuration with the allocated port if err := s.updateServerPort(server, serverPort); err != nil { return fmt.Errorf("failed to update server configuration: %v", err) } @@ -750,7 +613,6 @@ func (s *ServerService) configureFirewall(server *model.Server) error { } func (s *ServerService) updateServerPort(server *model.Server, port int) error { - // Load current configuration config, err := s.configService.GetConfiguration(server) if err != nil { return fmt.Errorf("failed to load server configuration: %v", err) @@ -759,7 +621,6 @@ func (s *ServerService) updateServerPort(server *model.Server, port int) error { config.TcpPort = model.IntString(port) config.UdpPort = model.IntString(port) - // Save the updated configuration if err := s.configService.SaveConfiguration(server, config); err != nil { return fmt.Errorf("failed to save server configuration: %v", err) } diff --git a/local/service/service.go b/local/service/service.go index b368cf3..46df608 100644 --- a/local/service/service.go +++ b/local/service/service.go @@ -7,17 +7,12 @@ import ( "go.uber.org/dig" ) -// InitializeServices -// Initializes Dependency Injection modules for services -// -// Args: -// *dig.Container: Dig Container +// *dig.Container: Dig Container func InitializeServices(c *dig.Container) { logging.Debug("Initializing repositories") repository.InitializeRepositories(c) logging.Debug("Registering services") - // Provide services c.Provide(NewSteamService) c.Provide(NewServerService) c.Provide(NewStateHistoryService) diff --git a/local/service/service_control.go b/local/service/service_control.go index c284959..b1794d5 100644 --- a/local/service/service_control.go +++ b/local/service/service_control.go @@ -24,9 +24,9 @@ func NewServiceControlService(repository *repository.ServiceControlRepository, repository: repository, serverRepository: serverRepository, statusCache: model.NewServerStatusCache(model.CacheConfig{ - ExpirationTime: 30 * time.Second, // Cache expires after 30 seconds - ThrottleTime: 5 * time.Second, // Minimum 5 seconds between checks - DefaultStatus: model.StatusRunning, // Default to running if throttled + ExpirationTime: 30 * time.Second, + ThrottleTime: 5 * time.Second, + DefaultStatus: model.StatusRunning, }), windowsService: NewWindowsService(), } @@ -42,18 +42,15 @@ func (as *ServiceControlService) GetStatus(ctx *fiber.Ctx) (string, error) { return "", err } - // Try to get status from cache if status, shouldCheck := as.statusCache.GetStatus(serviceName); !shouldCheck { return status.String(), nil } - // If cache miss or expired, check actual status statusStr, err := as.StatusServer(serviceName) if err != nil { return "", err } - // Parse and update cache with new status status := model.ParseServiceStatus(statusStr) as.statusCache.UpdateStatus(serviceName, status) return status.String(), nil @@ -65,7 +62,6 @@ func (as *ServiceControlService) ServiceControlStartServer(ctx *fiber.Ctx) (stri return "", err } - // Update status cache for this service before starting as.statusCache.UpdateStatus(serviceName, model.StatusStarting) _, err = as.StartServer(serviceName) @@ -77,7 +73,6 @@ func (as *ServiceControlService) ServiceControlStartServer(ctx *fiber.Ctx) (stri return "", err } - // Parse and update cache with new status status := model.ParseServiceStatus(statusStr) as.statusCache.UpdateStatus(serviceName, status) return status.String(), nil @@ -89,7 +84,6 @@ func (as *ServiceControlService) ServiceControlStopServer(ctx *fiber.Ctx) (strin return "", err } - // Update status cache for this service before stopping as.statusCache.UpdateStatus(serviceName, model.StatusStopping) _, err = as.StopServer(serviceName) @@ -101,7 +95,6 @@ func (as *ServiceControlService) ServiceControlStopServer(ctx *fiber.Ctx) (strin return "", err } - // Parse and update cache with new status status := model.ParseServiceStatus(statusStr) as.statusCache.UpdateStatus(serviceName, status) return status.String(), nil @@ -113,7 +106,6 @@ func (as *ServiceControlService) ServiceControlRestartServer(ctx *fiber.Ctx) (st return "", err } - // Update status cache for this service before restarting as.statusCache.UpdateStatus(serviceName, model.StatusRestarting) _, err = as.RestartServer(serviceName) @@ -125,7 +117,6 @@ func (as *ServiceControlService) ServiceControlRestartServer(ctx *fiber.Ctx) (st return "", err } - // Parse and update cache with new status status := model.ParseServiceStatus(statusStr) as.statusCache.UpdateStatus(serviceName, status) return status.String(), nil @@ -135,20 +126,16 @@ func (as *ServiceControlService) StatusServer(serviceName string) (string, error return as.windowsService.Status(context.Background(), serviceName) } -// GetCachedStatus gets the cached status for a service name without requiring fiber context func (as *ServiceControlService) GetCachedStatus(serviceName string) (string, error) { - // Try to get status from cache if status, shouldCheck := as.statusCache.GetStatus(serviceName); !shouldCheck { return status.String(), nil } - // If cache miss or expired, check actual status statusStr, err := as.StatusServer(serviceName) if err != nil { return "", err } - // Parse and update cache with new status status := model.ParseServiceStatus(statusStr) as.statusCache.UpdateStatus(serviceName, status) return status.String(), nil diff --git a/local/service/service_manager.go b/local/service/service_manager.go index 8de9965..13d1cad 100644 --- a/local/service/service_manager.go +++ b/local/service/service_manager.go @@ -6,7 +6,7 @@ import ( ) type ServiceManager struct { - executor *command.CommandExecutor + executor *command.CommandExecutor psExecutor *command.CommandExecutor } @@ -24,17 +24,14 @@ func NewServiceManager() *ServiceManager { } func (s *ServiceManager) ManageService(serviceName, action string) (string, error) { - // Run NSSM command through PowerShell to ensure elevation output, err := s.psExecutor.ExecuteWithOutput("-nologo", "-noprofile", ".\\nssm", action, serviceName) if err != nil { return "", err } - // Clean up output by removing null bytes and trimming whitespace cleaned := strings.TrimSpace(strings.ReplaceAll(output, "\x00", "")) - // Remove \r\n from status strings cleaned = strings.TrimSuffix(cleaned, "\r\n") - + return cleaned, nil } @@ -51,11 +48,9 @@ func (s *ServiceManager) Stop(serviceName string) (string, error) { } func (s *ServiceManager) Restart(serviceName string) (string, error) { - // First stop the service if _, err := s.Stop(serviceName); err != nil { return "", err } - // Then start it again return s.Start(serviceName) -} \ No newline at end of file +} diff --git a/local/service/state_history.go b/local/service/state_history.go index 29f7735..5bdd441 100644 --- a/local/service/state_history.go +++ b/local/service/state_history.go @@ -46,7 +46,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH eg, gCtx := errgroup.WithContext(ctx.UserContext()) - // Get Summary Stats (Peak/Avg Players, Total Sessions) eg.Go(func() error { summary, err := s.repository.GetSummaryStats(gCtx, filter) if err != nil { @@ -61,7 +60,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH return nil }) - // Get Total Playtime eg.Go(func() error { playtime, err := s.repository.GetTotalPlaytime(gCtx, filter) if err != nil { @@ -74,7 +72,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH return nil }) - // Get Player Count Over Time eg.Go(func() error { playerCount, err := s.repository.GetPlayerCountOverTime(gCtx, filter) if err != nil { @@ -87,7 +84,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH return nil }) - // Get Session Types eg.Go(func() error { sessionTypes, err := s.repository.GetSessionTypes(gCtx, filter) if err != nil { @@ -100,7 +96,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH return nil }) - // Get Daily Activity eg.Go(func() error { dailyActivity, err := s.repository.GetDailyActivity(gCtx, filter) if err != nil { @@ -113,7 +108,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH return nil }) - // Get Recent Sessions eg.Go(func() error { recentSessions, err := s.repository.GetRecentSessions(gCtx, filter) if err != nil { diff --git a/local/service/steam_service.go b/local/service/steam_service.go index ba0ea83..f6ef459 100644 --- a/local/service/steam_service.go +++ b/local/service/steam_service.go @@ -22,12 +22,11 @@ const ( ) type SteamService struct { - executor *command.CommandExecutor - interactiveExecutor *command.InteractiveCommandExecutor - repository *repository.SteamCredentialsRepository - tfaManager *model.Steam2FAManager - pathValidator *security.PathValidator - downloadVerifier *security.DownloadVerifier + executor *command.CommandExecutor + repository *repository.SteamCredentialsRepository + tfaManager *model.Steam2FAManager + pathValidator *security.PathValidator + downloadVerifier *security.DownloadVerifier } func NewSteamService(repository *repository.SteamCredentialsRepository, tfaManager *model.Steam2FAManager) *SteamService { @@ -37,12 +36,11 @@ func NewSteamService(repository *repository.SteamCredentialsRepository, tfaManag } return &SteamService{ - executor: baseExecutor, - interactiveExecutor: command.NewInteractiveCommandExecutor(baseExecutor, tfaManager), - repository: repository, - tfaManager: tfaManager, - pathValidator: security.NewPathValidator(), - downloadVerifier: security.NewDownloadVerifier(), + executor: baseExecutor, + repository: repository, + tfaManager: tfaManager, + pathValidator: security.NewPathValidator(), + downloadVerifier: security.NewDownloadVerifier(), } } @@ -58,21 +56,17 @@ func (s *SteamService) SaveCredentials(ctx context.Context, creds *model.SteamCr } func (s *SteamService) ensureSteamCMD(_ context.Context) error { - // Get SteamCMD path from environment variable steamCMDPath := env.GetSteamCMDPath() steamCMDDir := filepath.Dir(steamCMDPath) - // Check if SteamCMD exists if _, err := os.Stat(steamCMDPath); !os.IsNotExist(err) { return nil } - // Create directory if it doesn't exist if err := os.MkdirAll(steamCMDDir, 0755); err != nil { return fmt.Errorf("failed to create SteamCMD directory: %v", err) } - // Download and install SteamCMD securely logging.Info("Downloading SteamCMD...") steamCMDZip := filepath.Join(steamCMDDir, "steamcmd.zip") if err := s.downloadVerifier.VerifyAndDownload( @@ -82,150 +76,27 @@ func (s *SteamService) ensureSteamCMD(_ context.Context) error { return fmt.Errorf("failed to download SteamCMD: %v", err) } - // Extract SteamCMD logging.Info("Extracting SteamCMD...") if err := s.executor.Execute("-Command", fmt.Sprintf("Expand-Archive -Path 'steamcmd.zip' -DestinationPath '%s'", steamCMDDir)); err != nil { return fmt.Errorf("failed to extract SteamCMD: %v", err) } - // Clean up zip file os.Remove("steamcmd.zip") return nil } -func (s *SteamService) InstallServer(ctx context.Context, installPath string, serverID *uuid.UUID) error { - if err := s.ensureSteamCMD(ctx); err != nil { - return err - } - - // Validate installation path for security - if err := s.pathValidator.ValidateInstallPath(installPath); err != nil { - return fmt.Errorf("invalid installation path: %v", err) - } - - // Convert to absolute path and ensure proper Windows path format - absPath, err := filepath.Abs(installPath) - if err != nil { - return fmt.Errorf("failed to get absolute path: %v", err) - } - absPath = filepath.Clean(absPath) - - // Ensure install path exists - if err := os.MkdirAll(absPath, 0755); err != nil { - return fmt.Errorf("failed to create install directory: %v", err) - } - - // Get Steam credentials - creds, err := s.GetCredentials(ctx) - if err != nil { - return fmt.Errorf("failed to get Steam credentials: %v", err) - } - - // Get SteamCMD path from environment variable - steamCMDPath := env.GetSteamCMDPath() - - // Build SteamCMD command arguments - steamCMDArgs := []string{ - "+force_install_dir", absPath, - "+login", - } - - if creds != nil && creds.Username != "" { - logging.Info("Using Steam credentials for user: %s", creds.Username) - steamCMDArgs = append(steamCMDArgs, creds.Username) - if creds.Password != "" { - steamCMDArgs = append(steamCMDArgs, creds.Password) - } - } else { - logging.Info("Using anonymous Steam login") - steamCMDArgs = append(steamCMDArgs, "anonymous") - } - - steamCMDArgs = append(steamCMDArgs, - "+app_update", ACCServerAppID, - "validate", - "+quit", - ) - - // Execute SteamCMD directly without PowerShell wrapper to get better output capture - args := steamCMDArgs - - // Use interactive executor to handle potential 2FA prompts with timeout - logging.Info("Installing ACC server to %s...", absPath) - logging.Info("SteamCMD command: %s %s", steamCMDPath, strings.Join(args, " ")) - - // Create a context with timeout to prevent hanging indefinitely - timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) // Increased timeout - defer cancel() - - // Update the executor to use SteamCMD directly - originalExePath := s.interactiveExecutor.ExePath - s.interactiveExecutor.ExePath = steamCMDPath - defer func() { - s.interactiveExecutor.ExePath = originalExePath - }() - - if err := s.interactiveExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil { - logging.Error("SteamCMD execution failed: %v", err) - if timeoutCtx.Err() == context.DeadlineExceeded { - return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required") - } - return fmt.Errorf("failed to run SteamCMD: %v", err) - } - - logging.Info("SteamCMD execution completed successfully, proceeding with verification...") - - // Add a delay to allow Steam to properly cleanup - logging.Info("Waiting for Steam operations to complete...") - time.Sleep(5 * time.Second) - - // Verify installation - exePath := filepath.Join(absPath, "server", "accServer.exe") - logging.Info("Checking for ACC server executable at: %s", exePath) - - if _, err := os.Stat(exePath); os.IsNotExist(err) { - // Log directory contents to help debug - logging.Info("accServer.exe not found, checking directory contents...") - if entries, dirErr := os.ReadDir(absPath); dirErr == nil { - logging.Info("Contents of %s:", absPath) - for _, entry := range entries { - logging.Info(" - %s (dir: %v)", entry.Name(), entry.IsDir()) - } - } - - // Check if there's a server subdirectory - serverDir := filepath.Join(absPath, "server") - if entries, dirErr := os.ReadDir(serverDir); dirErr == nil { - logging.Info("Contents of %s:", serverDir) - for _, entry := range entries { - logging.Info(" - %s (dir: %v)", entry.Name(), entry.IsDir()) - } - } else { - logging.Info("Server directory %s does not exist or cannot be read: %v", serverDir, dirErr) - } - - return fmt.Errorf("server installation failed: accServer.exe not found in %s", exePath) - } - - logging.Info("Server installation completed successfully - accServer.exe found at %s", exePath) - return nil -} - -// InstallServerWithWebSocket installs a server with WebSocket output streaming func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPath string, serverID *uuid.UUID, wsService *WebSocketService) error { if err := s.ensureSteamCMD(ctx); err != nil { wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Error ensuring SteamCMD: %v", err), true) return err } - // Validate installation path for security if err := s.pathValidator.ValidateInstallPath(installPath); err != nil { wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Invalid installation path: %v", err), true) return fmt.Errorf("invalid installation path: %v", err) } - // Convert to absolute path and ensure proper Windows path format absPath, err := filepath.Abs(installPath) if err != nil { wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to get absolute path: %v", err), true) @@ -233,7 +104,6 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa } absPath = filepath.Clean(absPath) - // Ensure install path exists if err := os.MkdirAll(absPath, 0755); err != nil { wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to create install directory: %v", err), true) return fmt.Errorf("failed to create install directory: %v", err) @@ -241,17 +111,14 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Installation directory prepared: %s", absPath), false) - // Get Steam credentials creds, err := s.GetCredentials(ctx) if err != nil { wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to get Steam credentials: %v", err), true) return fmt.Errorf("failed to get Steam credentials: %v", err) } - // Get SteamCMD path from environment variable steamCMDPath := env.GetSteamCMDPath() - // Build SteamCMD command arguments steamCMDArgs := []string{ "+force_install_dir", absPath, "+login", @@ -274,27 +141,32 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa "+quit", ) - // Execute SteamCMD with WebSocket output streaming args := steamCMDArgs wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Starting SteamCMD: %s %s", steamCMDPath, strings.Join(args, " ")), false) - // Create a context with timeout to prevent hanging indefinitely timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) defer cancel() - // Update the executor to use SteamCMD directly - originalExePath := s.interactiveExecutor.ExePath - s.interactiveExecutor.ExePath = steamCMDPath - defer func() { - s.interactiveExecutor.ExePath = originalExePath - }() + callbackConfig := &command.CallbackConfig{ + OnOutput: func(serverID uuid.UUID, output string, isError bool) { + wsService.BroadcastSteamOutput(serverID, output, isError) + }, + OnCommand: func(serverID uuid.UUID, command string, args []string, completed bool, success bool, error string) { + if completed { + if success { + wsService.BroadcastSteamOutput(serverID, "Command completed successfully", false) + } else { + wsService.BroadcastSteamOutput(serverID, fmt.Sprintf("Command failed: %s", error), true) + } + } + }, + } - // Create a modified interactive executor that streams output to WebSocket - wsInteractiveExecutor := command.NewInteractiveCommandExecutorWithWebSocket(s.executor, s.tfaManager, wsService, *serverID) - wsInteractiveExecutor.ExePath = steamCMDPath + callbackInteractiveExecutor := command.NewCallbackInteractiveCommandExecutor(s.executor, s.tfaManager, callbackConfig, *serverID) + callbackInteractiveExecutor.ExePath = steamCMDPath - if err := wsInteractiveExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil { + if err := callbackInteractiveExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil { wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("SteamCMD execution failed: %v", err), true) if timeoutCtx.Err() == context.DeadlineExceeded { return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required") @@ -304,11 +176,9 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa wsService.BroadcastSteamOutput(*serverID, "SteamCMD execution completed successfully, proceeding with verification...", false) - // Add a delay to allow Steam to properly cleanup wsService.BroadcastSteamOutput(*serverID, "Waiting for Steam operations to complete...", false) time.Sleep(5 * time.Second) - // Verify installation exePath := filepath.Join(absPath, "server", "accServer.exe") wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Checking for ACC server executable at: %s", exePath), false) @@ -322,7 +192,6 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa } } - // Check if there's a server subdirectory serverDir := filepath.Join(absPath, "server") if entries, dirErr := os.ReadDir(serverDir); dirErr == nil { wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Contents of %s:", serverDir), false) @@ -341,8 +210,117 @@ func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPa return nil } -func (s *SteamService) UpdateServer(ctx context.Context, installPath string, serverID *uuid.UUID) error { - return s.InstallServer(ctx, installPath, serverID) // Same process as install +func (s *SteamService) InstallServerWithCallbacks(ctx context.Context, installPath string, serverID *uuid.UUID, outputCallback command.OutputCallback) error { + if err := s.ensureSteamCMD(ctx); err != nil { + outputCallback(*serverID, fmt.Sprintf("Error ensuring SteamCMD: %v", err), true) + return err + } + + if err := s.pathValidator.ValidateInstallPath(installPath); err != nil { + outputCallback(*serverID, fmt.Sprintf("Invalid installation path: %v", err), true) + return fmt.Errorf("invalid installation path: %v", err) + } + + absPath, err := filepath.Abs(installPath) + if err != nil { + outputCallback(*serverID, fmt.Sprintf("Failed to get absolute path: %v", err), true) + return fmt.Errorf("failed to get absolute path: %v", err) + } + absPath = filepath.Clean(absPath) + + if err := os.MkdirAll(absPath, 0755); err != nil { + outputCallback(*serverID, fmt.Sprintf("Failed to create install directory: %v", err), true) + return fmt.Errorf("failed to create install directory: %v", err) + } + + outputCallback(*serverID, fmt.Sprintf("Installation directory prepared: %s", absPath), false) + + creds, err := s.GetCredentials(ctx) + if err != nil { + outputCallback(*serverID, fmt.Sprintf("Failed to get Steam credentials: %v", err), true) + return fmt.Errorf("failed to get Steam credentials: %v", err) + } + + steamCMDPath := env.GetSteamCMDPath() + + steamCMDArgs := []string{ + "+force_install_dir", absPath, + "+login", + } + + if creds != nil && creds.Username != "" { + outputCallback(*serverID, fmt.Sprintf("Using Steam credentials for user: %s", creds.Username), false) + steamCMDArgs = append(steamCMDArgs, creds.Username) + if creds.Password != "" { + steamCMDArgs = append(steamCMDArgs, creds.Password) + } + } else { + outputCallback(*serverID, "Using anonymous Steam login", false) + steamCMDArgs = append(steamCMDArgs, "anonymous") + } + + steamCMDArgs = append(steamCMDArgs, + "+app_update", ACCServerAppID, + "validate", + "+quit", + ) + + args := steamCMDArgs + + outputCallback(*serverID, fmt.Sprintf("Starting SteamCMD: %s %s", steamCMDPath, strings.Join(args, " ")), false) + + timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) + defer cancel() + + callbacks := &command.CallbackConfig{ + OnOutput: outputCallback, + } + + callbackExecutor := command.NewCallbackInteractiveCommandExecutor(s.executor, s.tfaManager, callbacks, *serverID) + callbackExecutor.ExePath = steamCMDPath + + if err := callbackExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil { + outputCallback(*serverID, fmt.Sprintf("SteamCMD execution failed: %v", err), true) + if timeoutCtx.Err() == context.DeadlineExceeded { + return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required") + } + return fmt.Errorf("failed to run SteamCMD: %v", err) + } + + outputCallback(*serverID, "SteamCMD execution completed successfully, proceeding with verification...", false) + + outputCallback(*serverID, "Waiting for Steam operations to complete...", false) + time.Sleep(5 * time.Second) + + exePath := filepath.Join(absPath, "server", "accServer.exe") + outputCallback(*serverID, fmt.Sprintf("Checking for ACC server executable at: %s", exePath), false) + + if _, err := os.Stat(exePath); os.IsNotExist(err) { + outputCallback(*serverID, "accServer.exe not found, checking directory contents...", false) + + if entries, dirErr := os.ReadDir(absPath); dirErr == nil { + outputCallback(*serverID, fmt.Sprintf("Contents of %s:", absPath), false) + for _, entry := range entries { + outputCallback(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false) + } + } + + serverDir := filepath.Join(absPath, "server") + if entries, dirErr := os.ReadDir(serverDir); dirErr == nil { + outputCallback(*serverID, fmt.Sprintf("Contents of %s:", serverDir), false) + for _, entry := range entries { + outputCallback(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false) + } + } else { + outputCallback(*serverID, fmt.Sprintf("Server directory %s does not exist or cannot be read: %v", serverDir, dirErr), true) + } + + outputCallback(*serverID, fmt.Sprintf("Server installation failed: accServer.exe not found in %s", exePath), true) + return fmt.Errorf("server installation failed: accServer.exe not found in %s", exePath) + } + + outputCallback(*serverID, fmt.Sprintf("Server installation completed successfully - accServer.exe found at %s", exePath), false) + return nil } func (s *SteamService) UninstallServer(installPath string) error { diff --git a/local/service/websocket.go b/local/service/websocket.go index 8e7baca..9ecacca 100644 --- a/local/service/websocket.go +++ b/local/service/websocket.go @@ -11,25 +11,21 @@ import ( "github.com/google/uuid" ) -// WebSocketConnection represents a single WebSocket connection type WebSocketConnection struct { conn *websocket.Conn - serverID *uuid.UUID // If connected to a specific server creation process - userID *uuid.UUID // User who owns this connection + serverID *uuid.UUID + userID *uuid.UUID } -// WebSocketService manages WebSocket connections and message broadcasting type WebSocketService struct { - connections sync.Map // map[string]*WebSocketConnection - key is connection ID + connections sync.Map mu sync.RWMutex } -// NewWebSocketService creates a new WebSocket service func NewWebSocketService() *WebSocketService { return &WebSocketService{} } -// AddConnection adds a new WebSocket connection func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, userID *uuid.UUID) { wsConn := &WebSocketConnection{ conn: conn, @@ -39,7 +35,6 @@ func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, u logging.Info("WebSocket connection added: %s for user: %v", connID, userID) } -// RemoveConnection removes a WebSocket connection func (ws *WebSocketService) RemoveConnection(connID string) { if conn, exists := ws.connections.LoadAndDelete(connID); exists { if wsConn, ok := conn.(*WebSocketConnection); ok { @@ -49,7 +44,6 @@ func (ws *WebSocketService) RemoveConnection(connID string) { logging.Info("WebSocket connection removed: %s", connID) } -// SetServerID associates a connection with a specific server creation process func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) { if conn, exists := ws.connections.Load(connID); exists { if wsConn, ok := conn.(*WebSocketConnection); ok { @@ -58,7 +52,6 @@ func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) { } } -// BroadcastStep sends a step update to all connections associated with a server func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerCreationStep, status model.StepStatus, message string, errorMsg string) { stepMsg := model.StepMessage{ Step: step, @@ -77,7 +70,6 @@ func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerC ws.broadcastToServer(serverID, wsMsg) } -// BroadcastSteamOutput sends Steam command output to all connections associated with a server func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output string, isError bool) { steamMsg := model.SteamOutputMessage{ Output: output, @@ -94,7 +86,6 @@ func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output stri ws.broadcastToServer(serverID, wsMsg) } -// BroadcastError sends an error message to all connections associated with a server func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, details string) { errorMsg := model.ErrorMessage{ Error: error, @@ -111,7 +102,6 @@ func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, det ws.broadcastToServer(serverID, wsMsg) } -// BroadcastComplete sends a completion message to all connections associated with a server func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool, message string) { completeMsg := model.CompleteMessage{ ServerID: serverID, @@ -129,7 +119,6 @@ func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool, ws.broadcastToServer(serverID, wsMsg) } -// broadcastToServer sends a message to all connections associated with a specific server func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model.WebSocketMessage) { data, err := json.Marshal(message) if err != nil { @@ -137,22 +126,35 @@ func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model. return } + sentToAssociatedConnections := false + ws.connections.Range(func(key, value interface{}) bool { if wsConn, ok := value.(*WebSocketConnection); ok { - // Send to connections associated with this server if wsConn.serverID != nil && *wsConn.serverID == serverID { if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil { logging.Error("Failed to send WebSocket message to connection %s: %v", key, err) - // Remove the connection if it's broken ws.RemoveConnection(key.(string)) + } else { + sentToAssociatedConnections = true } } } return true }) + + if !sentToAssociatedConnections && (message.Type == model.MessageTypeStep || message.Type == model.MessageTypeError || message.Type == model.MessageTypeComplete) { + ws.connections.Range(func(key, value interface{}) bool { + if wsConn, ok := value.(*WebSocketConnection); ok { + if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil { + logging.Error("Failed to send WebSocket message to connection %s: %v", key, err) + ws.RemoveConnection(key.(string)) + } + } + return true + }) + } } -// BroadcastToUser sends a message to all connections owned by a specific user func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebSocketMessage) { data, err := json.Marshal(message) if err != nil { @@ -162,11 +164,9 @@ func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebS ws.connections.Range(func(key, value interface{}) bool { if wsConn, ok := value.(*WebSocketConnection); ok { - // Send to connections owned by this user if wsConn.userID != nil && *wsConn.userID == userID { if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil { logging.Error("Failed to send WebSocket message to connection %s: %v", key, err) - // Remove the connection if it's broken ws.RemoveConnection(key.(string)) } } @@ -175,7 +175,6 @@ func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebS }) } -// GetActiveConnections returns the count of active connections func (ws *WebSocketService) GetActiveConnections() int { count := 0 ws.connections.Range(func(key, value interface{}) bool { diff --git a/local/service/windows_service.go b/local/service/windows_service.go index c9324e8..32a67c2 100644 --- a/local/service/windows_service.go +++ b/local/service/windows_service.go @@ -23,34 +23,25 @@ func NewWindowsService() *WindowsService { } } -// executeNSSM runs an NSSM command through PowerShell with elevation func (s *WindowsService) ExecuteNSSM(ctx context.Context, args ...string) (string, error) { - // Get NSSM path from environment variable nssmPath := env.GetNSSMPath() - // Prepend NSSM path to arguments nssmArgs := append([]string{"-NoProfile", "-NonInteractive", "-Command", "& " + nssmPath}, args...) output, err := s.executor.ExecuteWithOutput(nssmArgs...) if err != nil { - // Log the full command and error for debugging logging.Error("NSSM command failed: powershell %s", strings.Join(nssmArgs, " ")) logging.Error("NSSM error output: %s", output) return "", err } - // Clean up output by removing null bytes and trimming whitespace cleaned := strings.TrimSpace(strings.ReplaceAll(output, "\x00", "")) - // Remove \r\n from status strings cleaned = strings.TrimSuffix(cleaned, "\r\n") return cleaned, nil } -// Service Installation/Configuration Methods - func (s *WindowsService) CreateService(ctx context.Context, serviceName, execPath, workingDir string, args []string) error { - // Ensure paths are absolute and properly formatted for Windows absExecPath, err := filepath.Abs(execPath) if err != nil { return fmt.Errorf("failed to get absolute path for executable: %v", err) @@ -63,30 +54,24 @@ func (s *WindowsService) CreateService(ctx context.Context, serviceName, execPat } absWorkingDir = filepath.Clean(absWorkingDir) - // Log the paths being used logging.Info("Creating service '%s' with:", serviceName) logging.Info(" Executable: %s", absExecPath) logging.Info(" Working Directory: %s", absWorkingDir) - // First remove any existing service with the same name s.ExecuteNSSM(ctx, "remove", serviceName, "confirm") - // Install service if _, err := s.ExecuteNSSM(ctx, "install", serviceName, absExecPath); err != nil { return fmt.Errorf("failed to install service: %v", err) } - // Set arguments if provided if len(args) > 0 { cmdArgs := append([]string{"set", serviceName, "AppParameters"}, args...) if _, err := s.ExecuteNSSM(ctx, cmdArgs...); err != nil { - // Try to clean up on failure s.ExecuteNSSM(ctx, "remove", serviceName, "confirm") return fmt.Errorf("failed to set arguments: %v", err) } } - // Verify service was created if _, err := s.ExecuteNSSM(ctx, "get", serviceName, "Application"); err != nil { return fmt.Errorf("service creation verification failed: %v", err) } @@ -105,17 +90,13 @@ func (s *WindowsService) DeleteService(ctx context.Context, serviceName string) } func (s *WindowsService) UpdateService(ctx context.Context, serviceName, execPath, workingDir string, args []string) error { - // First remove the existing service if err := s.DeleteService(ctx, serviceName); err != nil { return err } - // Then create it again with new parameters return s.CreateService(ctx, serviceName, execPath, workingDir, args) } -// Service Control Methods - func (s *WindowsService) Status(ctx context.Context, serviceName string) (string, error) { return s.ExecuteNSSM(ctx, "status", serviceName) } @@ -129,11 +110,9 @@ func (s *WindowsService) Stop(ctx context.Context, serviceName string) (string, } func (s *WindowsService) Restart(ctx context.Context, serviceName string) (string, error) { - // First stop the service if _, err := s.Stop(ctx, serviceName); err != nil { return "", err } - // Then start it again return s.Start(ctx, serviceName) } diff --git a/local/utl/cache/cache.go b/local/utl/cache/cache.go index 082ecac..5cf47e5 100644 --- a/local/utl/cache/cache.go +++ b/local/utl/cache/cache.go @@ -9,26 +9,22 @@ import ( "go.uber.org/dig" ) -// CacheItem represents an item in the cache type CacheItem struct { Value interface{} Expiration int64 } -// InMemoryCache is a thread-safe in-memory cache type InMemoryCache struct { items map[string]CacheItem mu sync.RWMutex } -// NewInMemoryCache creates and returns a new InMemoryCache instance func NewInMemoryCache() *InMemoryCache { return &InMemoryCache{ items: make(map[string]CacheItem), } } -// Set adds an item to the cache with an expiration duration (in seconds) func (c *InMemoryCache) Set(key string, value interface{}, duration time.Duration) { c.mu.Lock() defer c.mu.Unlock() @@ -44,7 +40,6 @@ func (c *InMemoryCache) Set(key string, value interface{}, duration time.Duratio } } -// Get retrieves an item from the cache func (c *InMemoryCache) Get(key string) (interface{}, bool) { c.mu.RLock() defer c.mu.RUnlock() @@ -55,24 +50,18 @@ func (c *InMemoryCache) Get(key string) (interface{}, bool) { } if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration { - // Item has expired, but don't delete here to avoid lock upgrade. - // It will be overwritten on the next Set. return nil, false } return item.Value, true } -// Delete removes an item from the cache func (c *InMemoryCache) Delete(key string) { c.mu.Lock() defer c.mu.Unlock() delete(c.items, key) } -// GetOrSet retrieves an item from the cache. If the item is not found, it -// calls the provided function to get the value, sets it in the cache, and -// returns it. func GetOrSet[T any](c *InMemoryCache, key string, duration time.Duration, fetcher func() (T, error)) (T, error) { if cached, found := c.Get(key); found { if value, ok := cached.(T); ok { @@ -90,7 +79,6 @@ func GetOrSet[T any](c *InMemoryCache, key string, duration time.Duration, fetch return value, nil } -// Start initializes the cache and provides it to the DI container. func Start(di *dig.Container) { cache := NewInMemoryCache() err := di.Provide(func() *InMemoryCache { diff --git a/local/utl/command/callback_executor.go b/local/utl/command/callback_executor.go new file mode 100644 index 0000000..2125317 --- /dev/null +++ b/local/utl/command/callback_executor.go @@ -0,0 +1,245 @@ +package command + +import ( + "acc-server-manager/local/model" + "acc-server-manager/local/utl/logging" + "bufio" + "context" + "fmt" + "io" + "os/exec" + "strings" + "time" + + "github.com/google/uuid" +) + +type CallbackInteractiveCommandExecutor struct { + *InteractiveCommandExecutor + callbacks *CallbackConfig + serverID uuid.UUID +} + +func NewCallbackInteractiveCommandExecutor(baseExecutor *CommandExecutor, tfaManager *model.Steam2FAManager, callbacks *CallbackConfig, serverID uuid.UUID) *CallbackInteractiveCommandExecutor { + if callbacks == nil { + callbacks = DefaultCallbackConfig() + } + + return &CallbackInteractiveCommandExecutor{ + InteractiveCommandExecutor: &InteractiveCommandExecutor{ + CommandExecutor: baseExecutor, + tfaManager: tfaManager, + }, + callbacks: callbacks, + serverID: serverID, + } +} + +func (e *CallbackInteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error { + cmd := exec.CommandContext(ctx, e.ExePath, args...) + + if e.WorkDir != "" { + cmd.Dir = e.WorkDir + } + + stdin, err := cmd.StdinPipe() + if err != nil { + return fmt.Errorf("failed to create stdin pipe: %v", err) + } + defer stdin.Close() + + stdout, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("failed to create stdout pipe: %v", err) + } + defer stdout.Close() + + stderr, err := cmd.StderrPipe() + if err != nil { + return fmt.Errorf("failed to create stderr pipe: %v", err) + } + defer stderr.Close() + + logging.Info("Executing interactive command with callbacks: %s %s", e.ExePath, strings.Join(args, " ")) + + e.callbacks.OnCommand(e.serverID, e.ExePath, args, false, false, "") + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Starting command: %s %s", e.ExePath, strings.Join(args, " ")), false) + + if err := cmd.Start(); err != nil { + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to start command: %v", err), true) + return fmt.Errorf("failed to start command: %v", err) + } + + outputDone := make(chan error, 1) + cmdDone := make(chan error, 1) + + go e.monitorOutputWithCallbacks(ctx, stdout, stderr, serverID, outputDone) + + go func() { + cmdDone <- cmd.Wait() + }() + + var cmdErr, outputErr error + completedCount := 0 + + for completedCount < 2 { + select { + case cmdErr = <-cmdDone: + completedCount++ + logging.Info("Command execution completed") + e.callbacks.OnOutput(e.serverID, "Command execution completed", false) + case outputErr = <-outputDone: + completedCount++ + logging.Info("Output monitoring completed") + case <-ctx.Done(): + e.callbacks.OnOutput(e.serverID, "Command execution cancelled", true) + return ctx.Err() + } + } + + if outputErr != nil { + logging.Warn("Output monitoring error: %v", outputErr) + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Output monitoring error: %v", outputErr), true) + } + + success := cmdErr == nil + errorMsg := "" + if cmdErr != nil { + errorMsg = cmdErr.Error() + } + e.callbacks.OnCommand(e.serverID, e.ExePath, args, true, success, errorMsg) + + return cmdErr +} + +func (e *CallbackInteractiveCommandExecutor) monitorOutputWithCallbacks(ctx context.Context, stdout, stderr io.Reader, serverID *uuid.UUID, done chan error) { + defer func() { + select { + case done <- nil: + default: + } + }() + + stdoutScanner := bufio.NewScanner(stdout) + stderrScanner := bufio.NewScanner(stderr) + + outputChan := make(chan outputLine, 100) + readersDone := make(chan struct{}, 2) + + steamConsoleStarted := false + tfaRequestCreated := false + + go func() { + defer func() { readersDone <- struct{}{} }() + for stdoutScanner.Scan() { + line := stdoutScanner.Text() + if e.LogOutput { + logging.Info("STDOUT: %s", line) + } + e.callbacks.OnOutput(e.serverID, line, false) + + select { + case outputChan <- outputLine{text: line, isError: false}: + case <-ctx.Done(): + return + } + } + if err := stdoutScanner.Err(); err != nil { + logging.Warn("Stdout scanner error: %v", err) + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Stdout scanner error: %v", err), true) + } + }() + + go func() { + defer func() { readersDone <- struct{}{} }() + for stderrScanner.Scan() { + line := stderrScanner.Text() + if e.LogOutput { + logging.Info("STDERR: %s", line) + } + e.callbacks.OnOutput(e.serverID, line, true) + + select { + case outputChan <- outputLine{text: line, isError: true}: + case <-ctx.Done(): + return + } + } + if err := stderrScanner.Err(); err != nil { + logging.Warn("Stderr scanner error: %v", err) + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Stderr scanner error: %v", err), true) + } + }() + + readersFinished := 0 + for { + select { + case <-ctx.Done(): + done <- ctx.Err() + return + case <-readersDone: + readersFinished++ + if readersFinished == 2 { + close(outputChan) + for lineData := range outputChan { + if e.is2FAPrompt(lineData.text) { + if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil { + logging.Error("Failed to handle 2FA prompt: %v", err) + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true) + done <- err + return + } + } + } + return + } + case lineData, ok := <-outputChan: + if !ok { + return + } + + lowerLine := strings.ToLower(lineData.text) + if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") { + steamConsoleStarted = true + logging.Info("Steam Console Client startup detected - will monitor for 2FA hang") + e.callbacks.OnOutput(e.serverID, "Steam Console Client startup detected", false) + } + + if e.is2FAPrompt(lineData.text) { + if !tfaRequestCreated { + e.callbacks.OnOutput(e.serverID, "2FA prompt detected - waiting for user confirmation", false) + if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil { + logging.Error("Failed to handle 2FA prompt: %v", err) + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true) + done <- err + return + } + tfaRequestCreated = true + } + } + + if tfaRequestCreated && e.isSteamContinuing(lineData.text) { + logging.Info("Steam CMD appears to have continued after 2FA confirmation") + e.callbacks.OnOutput(e.serverID, "Steam CMD continued after 2FA confirmation", false) + e.autoCompletePendingRequests(serverID) + } + case <-time.After(15 * time.Second): + if steamConsoleStarted && !tfaRequestCreated { + logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA") + e.callbacks.OnOutput(e.serverID, "Waiting for Steam Guard 2FA confirmation...", false) + if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil { + logging.Error("Failed to handle Steam Guard 2FA prompt: %v", err) + e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle Steam Guard 2FA prompt: %v", err), true) + done <- err + return + } + tfaRequestCreated = true + } + } + } +} + +type outputLine struct { + text string + isError bool +} diff --git a/local/utl/command/callbacks.go b/local/utl/command/callbacks.go new file mode 100644 index 0000000..1d1d060 --- /dev/null +++ b/local/utl/command/callbacks.go @@ -0,0 +1,19 @@ +package command + +import "github.com/google/uuid" + +type OutputCallback func(serverID uuid.UUID, output string, isError bool) + +type CommandCallback func(serverID uuid.UUID, command string, args []string, completed bool, success bool, error string) + +type CallbackConfig struct { + OnOutput OutputCallback + OnCommand CommandCallback +} + +func DefaultCallbackConfig() *CallbackConfig { + return &CallbackConfig{ + OnOutput: func(uuid.UUID, string, bool) {}, + OnCommand: func(uuid.UUID, string, []string, bool, bool, string) {}, + } +} diff --git a/local/utl/command/executor.go b/local/utl/command/executor.go index 75c0d07..256b5ee 100644 --- a/local/utl/command/executor.go +++ b/local/utl/command/executor.go @@ -8,17 +8,12 @@ import ( "strings" ) -// CommandExecutor provides a base structure for executing commands type CommandExecutor struct { - // Base executable path - ExePath string - // Working directory for commands - WorkDir string - // Whether to capture and log output + ExePath string + WorkDir string LogOutput bool } -// CommandBuilder helps build command arguments type CommandBuilder struct { args []string } @@ -48,10 +43,9 @@ func (b *CommandBuilder) Build() []string { return b.args } -// Execute runs a command with the given arguments func (e *CommandExecutor) Execute(args ...string) error { cmd := exec.Command(e.ExePath, args...) - + if e.WorkDir != "" { cmd.Dir = e.WorkDir } @@ -65,15 +59,13 @@ func (e *CommandExecutor) Execute(args ...string) error { return cmd.Run() } -// ExecuteWithBuilder runs a command using a CommandBuilder func (e *CommandExecutor) ExecuteWithBuilder(builder *CommandBuilder) error { return e.Execute(builder.Build()...) } -// ExecuteWithOutput runs a command and returns its output func (e *CommandExecutor) ExecuteWithOutput(args ...string) (string, error) { cmd := exec.Command(e.ExePath, args...) - + if e.WorkDir != "" { cmd.Dir = e.WorkDir } @@ -83,10 +75,9 @@ func (e *CommandExecutor) ExecuteWithOutput(args ...string) (string, error) { return string(output), err } -// ExecuteWithEnv runs a command with custom environment variables func (e *CommandExecutor) ExecuteWithEnv(env []string, args ...string) error { cmd := exec.Command(e.ExePath, args...) - + if e.WorkDir != "" { cmd.Dir = e.WorkDir } @@ -100,4 +91,4 @@ func (e *CommandExecutor) ExecuteWithEnv(env []string, args ...string) error { logging.Info("Executing command: %s %s", e.ExePath, strings.Join(args, " ")) return cmd.Run() -} \ No newline at end of file +} diff --git a/local/utl/command/interactive_executor.go b/local/utl/command/interactive_executor.go index 5b90ba2..1e6ff47 100644 --- a/local/utl/command/interactive_executor.go +++ b/local/utl/command/interactive_executor.go @@ -9,14 +9,12 @@ import ( "io" "os" "os/exec" - "reflect" "strings" "time" "github.com/google/uuid" ) -// InteractiveCommandExecutor extends CommandExecutor to handle interactive commands type InteractiveCommandExecutor struct { *CommandExecutor tfaManager *model.Steam2FAManager @@ -29,7 +27,6 @@ func NewInteractiveCommandExecutor(baseExecutor *CommandExecutor, tfaManager *mo } } -// ExecuteInteractive runs a command that may require 2FA input func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error { cmd := exec.CommandContext(ctx, e.ExePath, args...) @@ -37,7 +34,6 @@ func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, ser cmd.Dir = e.WorkDir } - // Create pipes for stdin, stdout, and stderr stdin, err := cmd.StdinPipe() if err != nil { return fmt.Errorf("failed to create stdin pipe: %v", err) @@ -58,7 +54,6 @@ func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, ser logging.Info("Executing interactive command: %s %s", e.ExePath, strings.Join(args, " ")) - // Enable debug mode if environment variable is set debugMode := os.Getenv("STEAMCMD_DEBUG") == "true" if debugMode { logging.Info("STEAMCMD_DEBUG mode enabled - will log all output and create proactive 2FA requests") @@ -68,19 +63,15 @@ func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, ser return fmt.Errorf("failed to start command: %v", err) } - // Create channels for output monitoring outputDone := make(chan error, 1) cmdDone := make(chan error, 1) - // Monitor stdout and stderr for 2FA prompts go e.monitorOutput(ctx, stdout, stderr, serverID, outputDone) - // Wait for the command to finish in a separate goroutine go func() { cmdDone <- cmd.Wait() }() - // Wait for both command and output monitoring to complete var cmdErr, outputErr error completedCount := 0 @@ -112,18 +103,15 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, } }() - // Create scanners for both outputs stdoutScanner := bufio.NewScanner(stdout) stderrScanner := bufio.NewScanner(stderr) - outputChan := make(chan string, 100) // Buffered channel to prevent blocking + outputChan := make(chan string, 100) readersDone := make(chan struct{}, 2) - // Track Steam Console startup for this specific execution steamConsoleStarted := false tfaRequestCreated := false - // Read from stdout go func() { defer func() { readersDone <- struct{}{} }() for stdoutScanner.Scan() { @@ -131,7 +119,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, if e.LogOutput { logging.Info("STDOUT: %s", line) } - // Always log Steam CMD output for debugging 2FA issues if strings.Contains(strings.ToLower(line), "steam") { logging.Info("STEAM_DEBUG: %s", line) } @@ -146,7 +133,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, } }() - // Read from stderr go func() { defer func() { readersDone <- struct{}{} }() for stderrScanner.Scan() { @@ -154,7 +140,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, if e.LogOutput { logging.Info("STDERR: %s", line) } - // Always log Steam CMD errors for debugging 2FA issues if strings.Contains(strings.ToLower(line), "steam") { logging.Info("STEAM_DEBUG_ERR: %s", line) } @@ -169,7 +154,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, } }() - // Monitor for completion and 2FA prompts readersFinished := 0 for { select { @@ -179,9 +163,7 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, case <-readersDone: readersFinished++ if readersFinished == 2 { - // Both readers are done, close output channel and finish monitoring close(outputChan) - // Drain any remaining output for line := range outputChan { if e.is2FAPrompt(line) { if err := e.handle2FAPrompt(ctx, line, serverID); err != nil { @@ -195,18 +177,15 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, } case line, ok := <-outputChan: if !ok { - // Channel closed, we're done return } - // Check for Steam Console startup lowerLine := strings.ToLower(line) if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") { steamConsoleStarted = true logging.Info("Steam Console Client startup detected - will monitor for 2FA hang") } - // Check if this line indicates a 2FA prompt if e.is2FAPrompt(line) { if !tfaRequestCreated { if err := e.handle2FAPrompt(ctx, line, serverID); err != nil { @@ -218,15 +197,11 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, } } - // Check if Steam CMD continued after 2FA (auto-completion) if tfaRequestCreated && e.isSteamContinuing(line) { logging.Info("Steam CMD appears to have continued after 2FA confirmation - auto-completing 2FA request") - // Auto-complete any pending 2FA requests for this server e.autoCompletePendingRequests(serverID) } case <-time.After(15 * time.Second): - // If Steam Console has started and we haven't seen output for 15 seconds, - // it's very likely waiting for 2FA confirmation if steamConsoleStarted && !tfaRequestCreated { logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA") if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil { @@ -243,7 +218,6 @@ func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, } func (e *InteractiveCommandExecutor) is2FAPrompt(line string) bool { - // Common SteamCMD 2FA prompts - updated with more comprehensive patterns twoFAKeywords := []string{ "please enter your steam guard code", "steam guard", @@ -272,7 +246,6 @@ func (e *InteractiveCommandExecutor) is2FAPrompt(line string) bool { } } - // Also check for patterns that might indicate Steam is waiting for input waitingPatterns := []string{ "waiting for", "please enter", @@ -330,12 +303,9 @@ func (e *InteractiveCommandExecutor) autoCompletePendingRequests(serverID *uuid. func (e *InteractiveCommandExecutor) handle2FAPrompt(_ context.Context, promptLine string, serverID *uuid.UUID) error { logging.Info("2FA prompt detected: %s", promptLine) - // Create a 2FA request request := e.tfaManager.CreateRequest(promptLine, serverID) logging.Info("Created 2FA request with ID: %s", request.ID) - // Wait for user to complete the 2FA process - // Use a reasonable timeout (e.g., 5 minutes) timeout := 5 * time.Minute success, err := e.tfaManager.WaitForCompletion(request.ID, timeout) @@ -352,271 +322,3 @@ func (e *InteractiveCommandExecutor) handle2FAPrompt(_ context.Context, promptLi logging.Info("2FA completed successfully") return nil } - -// WebSocketInteractiveCommandExecutor extends InteractiveCommandExecutor to stream output via WebSocket -type WebSocketInteractiveCommandExecutor struct { - *InteractiveCommandExecutor - wsService interface{} // Using interface{} to avoid circular import - serverID uuid.UUID -} - -// NewInteractiveCommandExecutorWithWebSocket creates a new WebSocket-enabled interactive command executor -func NewInteractiveCommandExecutorWithWebSocket(baseExecutor *CommandExecutor, tfaManager *model.Steam2FAManager, wsService interface{}, serverID uuid.UUID) *WebSocketInteractiveCommandExecutor { - return &WebSocketInteractiveCommandExecutor{ - InteractiveCommandExecutor: &InteractiveCommandExecutor{ - CommandExecutor: baseExecutor, - tfaManager: tfaManager, - }, - wsService: wsService, - serverID: serverID, - } -} - -// ExecuteInteractive runs a command with WebSocket output streaming -func (e *WebSocketInteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error { - cmd := exec.CommandContext(ctx, e.ExePath, args...) - - if e.WorkDir != "" { - cmd.Dir = e.WorkDir - } - - // Create pipes for stdin, stdout, and stderr - stdin, err := cmd.StdinPipe() - if err != nil { - return fmt.Errorf("failed to create stdin pipe: %v", err) - } - defer stdin.Close() - - stdout, err := cmd.StdoutPipe() - if err != nil { - return fmt.Errorf("failed to create stdout pipe: %v", err) - } - defer stdout.Close() - - stderr, err := cmd.StderrPipe() - if err != nil { - return fmt.Errorf("failed to create stderr pipe: %v", err) - } - defer stderr.Close() - - logging.Info("Executing interactive command with WebSocket streaming: %s %s", e.ExePath, strings.Join(args, " ")) - - // Broadcast command start via WebSocket - e.broadcastSteamOutput(fmt.Sprintf("Starting command: %s %s", e.ExePath, strings.Join(args, " ")), false) - - if err := cmd.Start(); err != nil { - e.broadcastSteamOutput(fmt.Sprintf("Failed to start command: %v", err), true) - return fmt.Errorf("failed to start command: %v", err) - } - - // Create channels for output monitoring - outputDone := make(chan error, 1) - cmdDone := make(chan error, 1) - - // Monitor stdout and stderr for 2FA prompts with WebSocket streaming - go e.monitorOutputWithWebSocket(ctx, stdout, stderr, serverID, outputDone) - - // Wait for the command to finish in a separate goroutine - go func() { - cmdDone <- cmd.Wait() - }() - - // Wait for both command and output monitoring to complete - var cmdErr, outputErr error - completedCount := 0 - - for completedCount < 2 { - select { - case cmdErr = <-cmdDone: - completedCount++ - logging.Info("Command execution completed") - e.broadcastSteamOutput("Command execution completed", false) - case outputErr = <-outputDone: - completedCount++ - logging.Info("Output monitoring completed") - case <-ctx.Done(): - e.broadcastSteamOutput("Command execution cancelled", true) - return ctx.Err() - } - } - - if outputErr != nil { - logging.Warn("Output monitoring error: %v", outputErr) - e.broadcastSteamOutput(fmt.Sprintf("Output monitoring error: %v", outputErr), true) - } - - return cmdErr -} - -// broadcastSteamOutput sends output to WebSocket using reflection to avoid circular imports -func (e *WebSocketInteractiveCommandExecutor) broadcastSteamOutput(output string, isError bool) { - if e.wsService == nil { - return - } - - // Use reflection to call BroadcastSteamOutput method - wsServiceVal := reflect.ValueOf(e.wsService) - method := wsServiceVal.MethodByName("BroadcastSteamOutput") - if !method.IsValid() { - logging.Warn("BroadcastSteamOutput method not found on WebSocket service") - return - } - - // Call the method with parameters: serverID, output, isError - args := []reflect.Value{ - reflect.ValueOf(e.serverID), - reflect.ValueOf(output), - reflect.ValueOf(isError), - } - method.Call(args) -} - -// monitorOutputWithWebSocket monitors command output and streams it via WebSocket -func (e *WebSocketInteractiveCommandExecutor) monitorOutputWithWebSocket(ctx context.Context, stdout, stderr io.Reader, serverID *uuid.UUID, done chan error) { - defer func() { - select { - case done <- nil: - default: - } - }() - - // Create scanners for both outputs - stdoutScanner := bufio.NewScanner(stdout) - stderrScanner := bufio.NewScanner(stderr) - - outputChan := make(chan outputLine, 100) // Buffered channel to prevent blocking - readersDone := make(chan struct{}, 2) - - // Track Steam Console startup for this specific execution - steamConsoleStarted := false - tfaRequestCreated := false - - // Read from stdout - go func() { - defer func() { readersDone <- struct{}{} }() - for stdoutScanner.Scan() { - line := stdoutScanner.Text() - if e.LogOutput { - logging.Info("STDOUT: %s", line) - } - // Stream output via WebSocket - e.broadcastSteamOutput(line, false) - - select { - case outputChan <- outputLine{text: line, isError: false}: - case <-ctx.Done(): - return - } - } - if err := stdoutScanner.Err(); err != nil { - logging.Warn("Stdout scanner error: %v", err) - e.broadcastSteamOutput(fmt.Sprintf("Stdout scanner error: %v", err), true) - } - }() - - // Read from stderr - go func() { - defer func() { readersDone <- struct{}{} }() - for stderrScanner.Scan() { - line := stderrScanner.Text() - if e.LogOutput { - logging.Info("STDERR: %s", line) - } - // Stream error output via WebSocket - e.broadcastSteamOutput(line, true) - - select { - case outputChan <- outputLine{text: line, isError: true}: - case <-ctx.Done(): - return - } - } - if err := stderrScanner.Err(); err != nil { - logging.Warn("Stderr scanner error: %v", err) - e.broadcastSteamOutput(fmt.Sprintf("Stderr scanner error: %v", err), true) - } - }() - - // Monitor for completion and 2FA prompts - readersFinished := 0 - for { - select { - case <-ctx.Done(): - done <- ctx.Err() - return - case <-readersDone: - readersFinished++ - if readersFinished == 2 { - // Both readers are done, close output channel and finish monitoring - close(outputChan) - // Drain any remaining output - for lineData := range outputChan { - if e.is2FAPrompt(lineData.text) { - if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil { - logging.Error("Failed to handle 2FA prompt: %v", err) - e.broadcastSteamOutput(fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true) - done <- err - return - } - } - } - return - } - case lineData, ok := <-outputChan: - if !ok { - // Channel closed, we're done - return - } - - // Check for Steam Console startup - lowerLine := strings.ToLower(lineData.text) - if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") { - steamConsoleStarted = true - logging.Info("Steam Console Client startup detected - will monitor for 2FA hang") - e.broadcastSteamOutput("Steam Console Client startup detected", false) - } - - // Check if this line indicates a 2FA prompt - if e.is2FAPrompt(lineData.text) { - if !tfaRequestCreated { - e.broadcastSteamOutput("2FA prompt detected - waiting for user confirmation", false) - if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil { - logging.Error("Failed to handle 2FA prompt: %v", err) - e.broadcastSteamOutput(fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true) - done <- err - return - } - tfaRequestCreated = true - } - } - - // Check if Steam CMD continued after 2FA (auto-completion) - if tfaRequestCreated && e.isSteamContinuing(lineData.text) { - logging.Info("Steam CMD appears to have continued after 2FA confirmation") - e.broadcastSteamOutput("Steam CMD continued after 2FA confirmation", false) - // Auto-complete any pending 2FA requests for this server - e.autoCompletePendingRequests(serverID) - } - case <-time.After(15 * time.Second): - // If Steam Console has started and we haven't seen output for 15 seconds, - // it's very likely waiting for 2FA confirmation - if steamConsoleStarted && !tfaRequestCreated { - logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA") - e.broadcastSteamOutput("Waiting for Steam Guard 2FA confirmation...", false) - if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil { - logging.Error("Failed to handle Steam Guard 2FA prompt: %v", err) - e.broadcastSteamOutput(fmt.Sprintf("Failed to handle Steam Guard 2FA prompt: %v", err), true) - done <- err - return - } - tfaRequestCreated = true - } - } - } -} - -// outputLine represents a line of output with error status -type outputLine struct { - text string - isError bool -} diff --git a/local/utl/common/common.go b/local/utl/common/common.go index 825e4c0..4757b59 100644 --- a/local/utl/common/common.go +++ b/local/utl/common/common.go @@ -79,12 +79,6 @@ func IndentJson(body []byte) ([]byte, error) { return unmarshaledBody.Bytes(), nil } -// ParseQueryFilter parses query parameters into a filter struct using reflection. -// It supports various field types and uses struct tags to determine parsing behavior. -// Supported tags: -// - `query:"field_name"` - specifies the query parameter name -// - `param:"param_name"` - specifies the path parameter name -// - `time_format:"format"` - specifies the time format for parsing dates (default: RFC3339) func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error { val := reflect.ValueOf(filter) if val.Kind() != reflect.Ptr || val.IsNil() { @@ -94,14 +88,12 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error { elem := val.Elem() typ := elem.Type() - // Process all fields including embedded structs var processFields func(reflect.Value, reflect.Type) error processFields = func(val reflect.Value, typ reflect.Type) error { for i := 0; i < val.NumField(); i++ { field := val.Field(i) fieldType := typ.Field(i) - // Handle embedded structs recursively if fieldType.Anonymous { if err := processFields(field, fieldType.Type); err != nil { return err @@ -109,12 +101,10 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error { continue } - // Skip if field cannot be set if !field.CanSet() { continue } - // Check for param tag first (path parameters) if paramName := fieldType.Tag.Get("param"); paramName != "" { if err := parsePathParam(c, field, paramName); err != nil { return fmt.Errorf("error parsing path parameter %s: %v", paramName, err) @@ -122,15 +112,14 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error { continue } - // Then check for query tag queryName := fieldType.Tag.Get("query") if queryName == "" { - queryName = ToSnakeCase(fieldType.Name) // Default to snake_case of field name + queryName = ToSnakeCase(fieldType.Name) } queryVal := c.Query(queryName) if queryVal == "" { - continue // Skip empty values + continue } if err := parseValue(field, queryVal, fieldType.Tag); err != nil { diff --git a/local/utl/configs/configs.go b/local/utl/configs/configs.go index 0a7b7a2..34074af 100644 --- a/local/utl/configs/configs.go +++ b/local/utl/configs/configs.go @@ -18,7 +18,6 @@ var ( func Init() { godotenv.Load() - // Fail fast if critical environment variables are missing Secret = getEnvRequired("APP_SECRET") SecretCode = getEnvRequired("APP_SECRET_CODE") EncryptionKey = getEnvRequired("ENCRYPTION_KEY") @@ -29,7 +28,6 @@ func Init() { } } -// getEnv retrieves an environment variable or returns a fallback value. func getEnv(key, fallback string) string { if value, exists := os.LookupEnv(key); exists { return value @@ -38,12 +36,10 @@ func getEnv(key, fallback string) string { return fallback } -// getEnvRequired retrieves an environment variable and fails if it's not set. -// This should be used for critical configuration that must not have defaults. func getEnvRequired(key string) string { if value, exists := os.LookupEnv(key); exists && value != "" { return value } log.Fatalf("Required environment variable %s is not set or is empty", key) - return "" // This line will never be reached due to log.Fatalf + return "" } diff --git a/local/utl/db/db.go b/local/utl/db/db.go index 8c45bed..7efa670 100644 --- a/local/utl/db/db.go +++ b/local/utl/db/db.go @@ -33,7 +33,6 @@ func Start(di *dig.Container) { func Migrate(db *gorm.DB) { logging.Info("Migrating database") - // Run GORM AutoMigrate for all models err := db.AutoMigrate( &model.ServiceControlModel{}, &model.Config{}, @@ -52,7 +51,6 @@ func Migrate(db *gorm.DB) { if err != nil { logging.Error("GORM AutoMigrate failed: %v", err) - // Don't panic, just log the error as custom migrations may have handled this } db.FirstOrCreate(&model.ServiceControlModel{ServiceControl: "Works"}) @@ -63,10 +61,8 @@ func Migrate(db *gorm.DB) { func runMigrations(db *gorm.DB) { logging.Info("Running custom database migrations...") - // Migration 001: Password security upgrade if err := migrations.RunPasswordSecurityMigration(db); err != nil { logging.Error("Failed to run password security migration: %v", err) - // Continue - this migration might not be needed for all setups } logging.Info("Custom database migrations completed") @@ -132,7 +128,6 @@ func seedCarModels(db *gorm.DB) error { carModels := []model.CarModel{ {Value: 0, CarModel: "Porsche 991 GT3 R"}, {Value: 1, CarModel: "Mercedes-AMG GT3"}, - // ... Add all car models from your list } for _, cm := range carModels { diff --git a/local/utl/env/env.go b/local/utl/env/env.go index 2a7c4b7..0f41b84 100644 --- a/local/utl/env/env.go +++ b/local/utl/env/env.go @@ -6,12 +6,10 @@ import ( ) const ( - // Default paths for when environment variables are not set DefaultSteamCMDPath = "c:\\steamcmd\\steamcmd.exe" DefaultNSSMPath = ".\\nssm.exe" ) -// GetSteamCMDPath returns the SteamCMD executable path from environment variable or default func GetSteamCMDPath() string { if path := os.Getenv("STEAMCMD_PATH"); path != "" { return path @@ -19,13 +17,11 @@ func GetSteamCMDPath() string { return DefaultSteamCMDPath } -// GetSteamCMDDirPath returns the directory containing SteamCMD executable func GetSteamCMDDirPath() string { steamCMDPath := GetSteamCMDPath() return filepath.Dir(steamCMDPath) } -// GetNSSMPath returns the NSSM executable path from environment variable or default func GetNSSMPath() string { if path := os.Getenv("NSSM_PATH"); path != "" { return path @@ -33,17 +29,14 @@ func GetNSSMPath() string { return DefaultNSSMPath } -// ValidatePaths checks if the configured paths exist (optional validation) func ValidatePaths() map[string]error { errors := make(map[string]error) - // Check SteamCMD path steamCMDPath := GetSteamCMDPath() if _, err := os.Stat(steamCMDPath); os.IsNotExist(err) { errors["STEAMCMD_PATH"] = err } - // Check NSSM path nssmPath := GetNSSMPath() if _, err := os.Stat(nssmPath); os.IsNotExist(err) { errors["NSSM_PATH"] = err diff --git a/local/utl/error_handler/controller_error_handler.go b/local/utl/error_handler/controller_error_handler.go index 1d4408a..067371f 100644 --- a/local/utl/error_handler/controller_error_handler.go +++ b/local/utl/error_handler/controller_error_handler.go @@ -9,45 +9,37 @@ import ( "github.com/gofiber/fiber/v2" ) -// ControllerErrorHandler provides centralized error handling for controllers type ControllerErrorHandler struct { errorLogger *logging.ErrorLogger } -// NewControllerErrorHandler creates a new controller error handler instance func NewControllerErrorHandler() *ControllerErrorHandler { return &ControllerErrorHandler{ errorLogger: logging.GetErrorLogger(), } } -// ErrorResponse represents a standardized error response type ErrorResponse struct { Error string `json:"error"` Code int `json:"code,omitempty"` Details map[string]string `json:"details,omitempty"` } -// HandleError handles controller errors with logging and standardized responses func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCode int, context ...string) error { if err == nil { return nil } - // Get caller information for logging _, file, line, _ := runtime.Caller(1) file = strings.TrimPrefix(file, "acc-server-manager/") - // Build context string contextStr := "" if len(context) > 0 { contextStr = fmt.Sprintf("[%s] ", strings.Join(context, "|")) } - // Clean error message (remove null bytes) cleanErrorMsg := strings.ReplaceAll(err.Error(), "\x00", "") - // Log the error with context ceh.errorLogger.LogWithContext( fmt.Sprintf("CONTROLLER_ERROR [%s:%d]", file, line), "%s%s", @@ -55,31 +47,26 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo cleanErrorMsg, ) - // Create standardized error response errorResponse := ErrorResponse{ Error: cleanErrorMsg, Code: statusCode, } - // Add request details if available if c != nil { if errorResponse.Details == nil { errorResponse.Details = make(map[string]string) } - - // Safely extract request details + func() { defer func() { if r := recover(); r != nil { - // If any of these panic, just skip adding the details return } }() - + errorResponse.Details["method"] = c.Method() errorResponse.Details["path"] = c.Path() - - // Safely get IP address + if ip := c.IP(); ip != "" { errorResponse.Details["ip"] = ip } else { @@ -88,14 +75,11 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo }() } - // Return appropriate response based on status code if c == nil { - // If context is nil, we can't return a response return fmt.Errorf("cannot return HTTP response: context is nil") } - + if statusCode >= 500 { - // For server errors, don't expose internal details return c.Status(statusCode).JSON(ErrorResponse{ Error: "Internal server error", Code: statusCode, @@ -105,52 +89,42 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo return c.Status(statusCode).JSON(errorResponse) } -// HandleValidationError handles validation errors specifically func (ceh *ControllerErrorHandler) HandleValidationError(c *fiber.Ctx, err error, field string) error { return ceh.HandleError(c, err, fiber.StatusBadRequest, "VALIDATION", field) } -// HandleDatabaseError handles database-related errors func (ceh *ControllerErrorHandler) HandleDatabaseError(c *fiber.Ctx, err error) error { return ceh.HandleError(c, err, fiber.StatusInternalServerError, "DATABASE") } -// HandleAuthError handles authentication/authorization errors func (ceh *ControllerErrorHandler) HandleAuthError(c *fiber.Ctx, err error) error { return ceh.HandleError(c, err, fiber.StatusUnauthorized, "AUTH") } -// HandleNotFoundError handles resource not found errors func (ceh *ControllerErrorHandler) HandleNotFoundError(c *fiber.Ctx, resource string) error { err := fmt.Errorf("%s not found", resource) return ceh.HandleError(c, err, fiber.StatusNotFound, "NOT_FOUND") } -// HandleBusinessLogicError handles business logic errors func (ceh *ControllerErrorHandler) HandleBusinessLogicError(c *fiber.Ctx, err error) error { return ceh.HandleError(c, err, fiber.StatusBadRequest, "BUSINESS_LOGIC") } -// HandleServiceError handles service layer errors func (ceh *ControllerErrorHandler) HandleServiceError(c *fiber.Ctx, err error) error { return ceh.HandleError(c, err, fiber.StatusInternalServerError, "SERVICE") } -// HandleParsingError handles request parsing errors func (ceh *ControllerErrorHandler) HandleParsingError(c *fiber.Ctx, err error) error { return ceh.HandleError(c, err, fiber.StatusBadRequest, "PARSING") } -// HandleUUIDError handles UUID parsing errors func (ceh *ControllerErrorHandler) HandleUUIDError(c *fiber.Ctx, field string) error { err := fmt.Errorf("invalid %s format", field) return ceh.HandleError(c, err, fiber.StatusBadRequest, "UUID_VALIDATION", field) } -// Global controller error handler instance var globalErrorHandler *ControllerErrorHandler -// GetControllerErrorHandler returns the global controller error handler instance func GetControllerErrorHandler() *ControllerErrorHandler { if globalErrorHandler == nil { globalErrorHandler = NewControllerErrorHandler() @@ -158,49 +132,38 @@ func GetControllerErrorHandler() *ControllerErrorHandler { return globalErrorHandler } -// Convenience functions using the global error handler - -// HandleError handles controller errors using the global error handler func HandleError(c *fiber.Ctx, err error, statusCode int, context ...string) error { return GetControllerErrorHandler().HandleError(c, err, statusCode, context...) } -// HandleValidationError handles validation errors using the global error handler func HandleValidationError(c *fiber.Ctx, err error, field string) error { return GetControllerErrorHandler().HandleValidationError(c, err, field) } -// HandleDatabaseError handles database errors using the global error handler func HandleDatabaseError(c *fiber.Ctx, err error) error { return GetControllerErrorHandler().HandleDatabaseError(c, err) } -// HandleAuthError handles auth errors using the global error handler func HandleAuthError(c *fiber.Ctx, err error) error { return GetControllerErrorHandler().HandleAuthError(c, err) } -// HandleNotFoundError handles not found errors using the global error handler func HandleNotFoundError(c *fiber.Ctx, resource string) error { return GetControllerErrorHandler().HandleNotFoundError(c, resource) } -// HandleBusinessLogicError handles business logic errors using the global error handler func HandleBusinessLogicError(c *fiber.Ctx, err error) error { return GetControllerErrorHandler().HandleBusinessLogicError(c, err) } -// HandleServiceError handles service errors using the global error handler func HandleServiceError(c *fiber.Ctx, err error) error { return GetControllerErrorHandler().HandleServiceError(c, err) } -// HandleParsingError handles parsing errors using the global error handler func HandleParsingError(c *fiber.Ctx, err error) error { return GetControllerErrorHandler().HandleParsingError(c, err) } -// HandleUUIDError handles UUID errors using the global error handler func HandleUUIDError(c *fiber.Ctx, field string) error { return GetControllerErrorHandler().HandleUUIDError(c, field) } diff --git a/local/utl/jwt/jwt.go b/local/utl/jwt/jwt.go index 9cd036f..9d4cdff 100644 --- a/local/utl/jwt/jwt.go +++ b/local/utl/jwt/jwt.go @@ -11,7 +11,6 @@ import ( "github.com/golang-jwt/jwt/v4" ) -// Claims represents the JWT claims. type Claims struct { UserID string `json:"user_id"` IsOpenToken bool `json:"is_open_token"` @@ -27,7 +26,6 @@ type OpenJWTHandler struct { *JWTHandler } -// NewJWTHandler creates a new JWTHandler instance with the provided secret key. func NewOpenJWTHandler(jwtSecret string) *OpenJWTHandler { jwtHandler := NewJWTHandler(jwtSecret) jwtHandler.IsOpenToken = true @@ -36,7 +34,6 @@ func NewOpenJWTHandler(jwtSecret string) *OpenJWTHandler { } } -// NewJWTHandler creates a new JWTHandler instance with the provided secret key. func NewJWTHandler(jwtSecret string) *JWTHandler { if jwtSecret == "" { errors.SafeFatal("JWT_SECRET environment variable is required and cannot be empty") @@ -44,14 +41,12 @@ func NewJWTHandler(jwtSecret string) *JWTHandler { var secretKey []byte - // Decode base64 secret if it looks like base64, otherwise use as-is if decoded, err := base64.StdEncoding.DecodeString(jwtSecret); err == nil && len(decoded) >= 32 { secretKey = decoded } else { secretKey = []byte(jwtSecret) } - // Ensure minimum key length for security if len(secretKey) < 32 { errors.SafeFatal("JWT_SECRET must be at least 32 bytes long for security") } @@ -60,8 +55,6 @@ func NewJWTHandler(jwtSecret string) *JWTHandler { } } -// GenerateSecretKey generates a cryptographically secure random key for JWT signing -// This is a utility function for generating new secrets, not used in normal operation func (jh *JWTHandler) GenerateSecretKey() string { key := make([]byte, 64) // 512 bits if _, err := rand.Read(key); err != nil { @@ -70,7 +63,6 @@ func (jh *JWTHandler) GenerateSecretKey() string { return base64.StdEncoding.EncodeToString(key) } -// GenerateToken generates a new JWT for a given user. func (jh *JWTHandler) GenerateToken(userId string) (string, error) { expirationTime := time.Now().Add(24 * time.Hour) claims := &Claims{ @@ -99,7 +91,6 @@ func (jh *JWTHandler) GenerateTokenWithExpiry(user *model.User, expiry time.Time return token.SignedString(jh.SecretKey) } -// ValidateToken validates a JWT and returns the claims if the token is valid. func (jh *JWTHandler) ValidateToken(tokenString string) (*Claims, error) { claims := &Claims{} diff --git a/local/utl/logging/base.go b/local/utl/logging/base.go index 03ef1fa..c55b7bb 100644 --- a/local/utl/logging/base.go +++ b/local/utl/logging/base.go @@ -15,7 +15,6 @@ var ( timeFormat = "2006-01-02 15:04:05.000" ) -// BaseLogger provides the core logging functionality type BaseLogger struct { file *os.File logger *log.Logger @@ -23,7 +22,6 @@ type BaseLogger struct { initialized bool } -// LogLevel represents different logging levels type LogLevel string const ( @@ -34,28 +32,23 @@ const ( LogLevelPanic LogLevel = "PANIC" ) -// Initialize creates a new base logger instance func InitializeBase(tp string) (*BaseLogger, error) { return newBaseLogger(tp) } func newBaseLogger(tp string) (*BaseLogger, error) { - // Ensure logs directory exists if err := os.MkdirAll("logs", 0755); err != nil { return nil, fmt.Errorf("failed to create logs directory: %v", err) } - // Open log file with date in name logPath := filepath.Join("logs", fmt.Sprintf("acc-server-%s-%s.log", time.Now().Format("2006-01-02"), tp)) file, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0666) if err != nil { return nil, fmt.Errorf("failed to open log file: %v", err) } - // Create multi-writer for both file and console multiWriter := io.MultiWriter(file, os.Stdout) - // Create base logger logger := &BaseLogger{ file: file, logger: log.New(multiWriter, "", 0), @@ -65,13 +58,11 @@ func newBaseLogger(tp string) (*BaseLogger, error) { return logger, nil } -// GetBaseLogger creates and returns a new base logger instance func GetBaseLogger(tp string) *BaseLogger { baseLogger, _ := InitializeBase(tp) return baseLogger } -// Close closes the log file func (bl *BaseLogger) Close() error { bl.mu.Lock() defer bl.mu.Unlock() @@ -82,7 +73,6 @@ func (bl *BaseLogger) Close() error { return nil } -// Log writes a log entry with the specified level func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) { if bl == nil || !bl.initialized { return @@ -91,14 +81,11 @@ func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) { bl.mu.RLock() defer bl.mu.RUnlock() - // Get caller info (skip 2 frames: this function and the calling Log function) _, file, line, _ := runtime.Caller(2) file = filepath.Base(file) - // Format message msg := fmt.Sprintf(format, v...) - // Format final log line logLine := fmt.Sprintf("[%s] [%s] [%s:%d] %s", time.Now().Format(timeFormat), string(level), @@ -110,7 +97,6 @@ func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) { bl.logger.Println(logLine) } -// LogWithCaller writes a log entry with custom caller depth func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format string, v ...interface{}) { if bl == nil || !bl.initialized { return @@ -119,14 +105,11 @@ func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format stri bl.mu.RLock() defer bl.mu.RUnlock() - // Get caller info with custom depth _, file, line, _ := runtime.Caller(callerDepth) file = filepath.Base(file) - // Format message msg := fmt.Sprintf(format, v...) - // Format final log line logLine := fmt.Sprintf("[%s] [%s] [%s:%d] %s", time.Now().Format(timeFormat), string(level), @@ -138,7 +121,6 @@ func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format stri bl.logger.Println(logLine) } -// IsInitialized returns whether the base logger is initialized func (bl *BaseLogger) IsInitialized() bool { if bl == nil { return false @@ -148,19 +130,16 @@ func (bl *BaseLogger) IsInitialized() bool { return bl.initialized } -// RecoverAndLog recovers from panics and logs them func RecoverAndLog() { baseLogger := GetBaseLogger("panic") if baseLogger != nil && baseLogger.IsInitialized() { if r := recover(); r != nil { - // Get stack trace buf := make([]byte, 4096) n := runtime.Stack(buf, false) stackTrace := string(buf[:n]) baseLogger.LogWithCaller(LogLevelPanic, 2, "Recovered from panic: %v\nStack Trace:\n%s", r, stackTrace) - // Re-panic to maintain original behavior if needed panic(r) } } diff --git a/local/utl/logging/debug.go b/local/utl/logging/debug.go index 0422ad6..a1bbf6d 100644 --- a/local/utl/logging/debug.go +++ b/local/utl/logging/debug.go @@ -6,12 +6,10 @@ import ( "sync" ) -// DebugLogger handles debug-level logging type DebugLogger struct { base *BaseLogger } -// NewDebugLogger creates a new debug logger instance func NewDebugLogger() *DebugLogger { base, _ := InitializeBase("debug") return &DebugLogger{ @@ -19,14 +17,12 @@ func NewDebugLogger() *DebugLogger { } } -// Log writes a debug-level log entry func (dl *DebugLogger) Log(format string, v ...interface{}) { if dl.base != nil { dl.base.Log(LogLevelDebug, format, v...) } } -// LogWithContext writes a debug-level log entry with additional context func (dl *DebugLogger) LogWithContext(context string, format string, v ...interface{}) { if dl.base != nil { contextualFormat := fmt.Sprintf("[%s] %s", context, format) @@ -34,7 +30,6 @@ func (dl *DebugLogger) LogWithContext(context string, format string, v ...interf } } -// LogFunction logs function entry and exit for debugging func (dl *DebugLogger) LogFunction(functionName string, args ...interface{}) { if dl.base != nil { if len(args) > 0 { @@ -45,21 +40,18 @@ func (dl *DebugLogger) LogFunction(functionName string, args ...interface{}) { } } -// LogVariable logs variable values for debugging func (dl *DebugLogger) LogVariable(varName string, value interface{}) { if dl.base != nil { dl.base.Log(LogLevelDebug, "VARIABLE [%s]: %+v", varName, value) } } -// LogState logs application state information func (dl *DebugLogger) LogState(component string, state interface{}) { if dl.base != nil { dl.base.Log(LogLevelDebug, "STATE [%s]: %+v", component, state) } } -// LogSQL logs SQL queries for debugging func (dl *DebugLogger) LogSQL(query string, args ...interface{}) { if dl.base != nil { if len(args) > 0 { @@ -70,7 +62,6 @@ func (dl *DebugLogger) LogSQL(query string, args ...interface{}) { } } -// LogMemory logs memory usage information func (dl *DebugLogger) LogMemory() { if dl.base != nil { var m runtime.MemStats @@ -80,32 +71,27 @@ func (dl *DebugLogger) LogMemory() { } } -// LogGoroutines logs current number of goroutines func (dl *DebugLogger) LogGoroutines() { if dl.base != nil { dl.base.Log(LogLevelDebug, "GOROUTINES: %d active", runtime.NumGoroutine()) } } -// LogTiming logs timing information for performance debugging func (dl *DebugLogger) LogTiming(operation string, duration interface{}) { if dl.base != nil { dl.base.Log(LogLevelDebug, "TIMING [%s]: %v", operation, duration) } } -// Helper function to convert bytes to kilobytes func bToKb(b uint64) uint64 { return b / 1024 } -// Global debug logger instance var ( debugLogger *DebugLogger debugOnce sync.Once ) -// GetDebugLogger returns the global debug logger instance func GetDebugLogger() *DebugLogger { debugOnce.Do(func() { debugLogger = NewDebugLogger() @@ -113,47 +99,38 @@ func GetDebugLogger() *DebugLogger { return debugLogger } -// Debug logs a debug-level message using the global debug logger func Debug(format string, v ...interface{}) { GetDebugLogger().Log(format, v...) } -// DebugWithContext logs a debug-level message with context using the global debug logger func DebugWithContext(context string, format string, v ...interface{}) { GetDebugLogger().LogWithContext(context, format, v...) } -// DebugFunction logs function entry and exit using the global debug logger func DebugFunction(functionName string, args ...interface{}) { GetDebugLogger().LogFunction(functionName, args...) } -// DebugVariable logs variable values using the global debug logger func DebugVariable(varName string, value interface{}) { GetDebugLogger().LogVariable(varName, value) } -// DebugState logs application state information using the global debug logger func DebugState(component string, state interface{}) { GetDebugLogger().LogState(component, state) } -// DebugSQL logs SQL queries using the global debug logger func DebugSQL(query string, args ...interface{}) { GetDebugLogger().LogSQL(query, args...) } -// DebugMemory logs memory usage information using the global debug logger func DebugMemory() { GetDebugLogger().LogMemory() } -// DebugGoroutines logs current number of goroutines using the global debug logger func DebugGoroutines() { GetDebugLogger().LogGoroutines() } -// DebugTiming logs timing information using the global debug logger func DebugTiming(operation string, duration interface{}) { GetDebugLogger().LogTiming(operation, duration) } diff --git a/local/utl/logging/error.go b/local/utl/logging/error.go index d347bbd..1737ef0 100644 --- a/local/utl/logging/error.go +++ b/local/utl/logging/error.go @@ -6,12 +6,10 @@ import ( "sync" ) -// ErrorLogger handles error-level logging type ErrorLogger struct { base *BaseLogger } -// NewErrorLogger creates a new error logger instance func NewErrorLogger() *ErrorLogger { base, _ := InitializeBase("error") return &ErrorLogger{ @@ -19,14 +17,12 @@ func NewErrorLogger() *ErrorLogger { } } -// Log writes an error-level log entry func (el *ErrorLogger) Log(format string, v ...interface{}) { if el.base != nil { el.base.Log(LogLevelError, format, v...) } } -// LogWithContext writes an error-level log entry with additional context func (el *ErrorLogger) LogWithContext(context string, format string, v ...interface{}) { if el.base != nil { contextualFormat := fmt.Sprintf("[%s] %s", context, format) @@ -34,7 +30,6 @@ func (el *ErrorLogger) LogWithContext(context string, format string, v ...interf } } -// LogError logs an error object with optional message func (el *ErrorLogger) LogError(err error, message ...string) { if el.base != nil && err != nil { if len(message) > 0 { @@ -45,7 +40,6 @@ func (el *ErrorLogger) LogError(err error, message ...string) { } } -// LogWithStackTrace logs an error with stack trace func (el *ErrorLogger) LogWithStackTrace(format string, v ...interface{}) { if el.base != nil { // Get stack trace @@ -58,7 +52,6 @@ func (el *ErrorLogger) LogWithStackTrace(format string, v ...interface{}) { } } -// LogFatal logs a fatal error and exits the program func (el *ErrorLogger) LogFatal(format string, v ...interface{}) { if el.base != nil { el.base.Log(LogLevelError, "[FATAL] "+format, v...) @@ -66,13 +59,11 @@ func (el *ErrorLogger) LogFatal(format string, v ...interface{}) { } } -// Global error logger instance var ( errorLogger *ErrorLogger errorOnce sync.Once ) -// GetErrorLogger returns the global error logger instance func GetErrorLogger() *ErrorLogger { errorOnce.Do(func() { errorLogger = NewErrorLogger() @@ -80,27 +71,22 @@ func GetErrorLogger() *ErrorLogger { return errorLogger } -// Error logs an error-level message using the global error logger func Error(format string, v ...interface{}) { GetErrorLogger().Log(format, v...) } -// ErrorWithContext logs an error-level message with context using the global error logger func ErrorWithContext(context string, format string, v ...interface{}) { GetErrorLogger().LogWithContext(context, format, v...) } -// LogError logs an error object using the global error logger func LogError(err error, message ...string) { GetErrorLogger().LogError(err, message...) } -// ErrorWithStackTrace logs an error with stack trace using the global error logger func ErrorWithStackTrace(format string, v ...interface{}) { GetErrorLogger().LogWithStackTrace(format, v...) } -// Fatal logs a fatal error and exits the program using the global error logger func Fatal(format string, v ...interface{}) { GetErrorLogger().LogFatal(format, v...) } diff --git a/local/utl/logging/info.go b/local/utl/logging/info.go index 17553f0..611064b 100644 --- a/local/utl/logging/info.go +++ b/local/utl/logging/info.go @@ -5,12 +5,10 @@ import ( "sync" ) -// InfoLogger handles info-level logging type InfoLogger struct { base *BaseLogger } -// NewInfoLogger creates a new info logger instance func NewInfoLogger() *InfoLogger { base, _ := InitializeBase("info") return &InfoLogger{ @@ -18,14 +16,12 @@ func NewInfoLogger() *InfoLogger { } } -// Log writes an info-level log entry func (il *InfoLogger) Log(format string, v ...interface{}) { if il.base != nil { il.base.Log(LogLevelInfo, format, v...) } } -// LogWithContext writes an info-level log entry with additional context func (il *InfoLogger) LogWithContext(context string, format string, v ...interface{}) { if il.base != nil { contextualFormat := fmt.Sprintf("[%s] %s", context, format) @@ -33,55 +29,47 @@ func (il *InfoLogger) LogWithContext(context string, format string, v ...interfa } } -// LogStartup logs application startup information func (il *InfoLogger) LogStartup(component string, message string) { if il.base != nil { il.base.Log(LogLevelInfo, "STARTUP [%s]: %s", component, message) } } -// LogShutdown logs application shutdown information func (il *InfoLogger) LogShutdown(component string, message string) { if il.base != nil { il.base.Log(LogLevelInfo, "SHUTDOWN [%s]: %s", component, message) } } -// LogOperation logs general operation information func (il *InfoLogger) LogOperation(operation string, details string) { if il.base != nil { il.base.Log(LogLevelInfo, "OPERATION [%s]: %s", operation, details) } } -// LogStatus logs status changes or updates func (il *InfoLogger) LogStatus(component string, status string) { if il.base != nil { il.base.Log(LogLevelInfo, "STATUS [%s]: %s", component, status) } } -// LogRequest logs incoming requests func (il *InfoLogger) LogRequest(method string, path string, userAgent string) { if il.base != nil { il.base.Log(LogLevelInfo, "REQUEST [%s %s] User-Agent: %s", method, path, userAgent) } } -// LogResponse logs outgoing responses func (il *InfoLogger) LogResponse(method string, path string, statusCode int, duration string) { if il.base != nil { il.base.Log(LogLevelInfo, "RESPONSE [%s %s] Status: %d, Duration: %s", method, path, statusCode, duration) } } -// Global info logger instance var ( infoLogger *InfoLogger infoOnce sync.Once ) -// GetInfoLogger returns the global info logger instance func GetInfoLogger() *InfoLogger { infoOnce.Do(func() { infoLogger = NewInfoLogger() @@ -89,42 +77,34 @@ func GetInfoLogger() *InfoLogger { return infoLogger } -// Info logs an info-level message using the global info logger func Info(format string, v ...interface{}) { GetInfoLogger().Log(format, v...) } -// InfoWithContext logs an info-level message with context using the global info logger func InfoWithContext(context string, format string, v ...interface{}) { GetInfoLogger().LogWithContext(context, format, v...) } -// InfoStartup logs application startup information using the global info logger func InfoStartup(component string, message string) { GetInfoLogger().LogStartup(component, message) } -// InfoShutdown logs application shutdown information using the global info logger func InfoShutdown(component string, message string) { GetInfoLogger().LogShutdown(component, message) } -// InfoOperation logs general operation information using the global info logger func InfoOperation(operation string, details string) { GetInfoLogger().LogOperation(operation, details) } -// InfoStatus logs status changes or updates using the global info logger func InfoStatus(component string, status string) { GetInfoLogger().LogStatus(component, status) } -// InfoRequest logs incoming requests using the global info logger func InfoRequest(method string, path string, userAgent string) { GetInfoLogger().LogRequest(method, path, userAgent) } -// InfoResponse logs outgoing responses using the global info logger func InfoResponse(method string, path string, statusCode int, duration string) { GetInfoLogger().LogResponse(method, path, statusCode, duration) } diff --git a/local/utl/logging/logger.go b/local/utl/logging/logger.go index 258c0a9..e0079c4 100644 --- a/local/utl/logging/logger.go +++ b/local/utl/logging/logger.go @@ -6,12 +6,10 @@ import ( ) var ( - // Legacy logger for backward compatibility logger *Logger once sync.Once ) -// Logger maintains backward compatibility with existing code type Logger struct { base *BaseLogger errorLogger *ErrorLogger @@ -20,8 +18,6 @@ type Logger struct { debugLogger *DebugLogger } -// Initialize creates or gets the singleton logger instance -// This maintains backward compatibility with existing code func Initialize() (*Logger, error) { var err error once.Do(func() { @@ -31,13 +27,11 @@ func Initialize() (*Logger, error) { } func newLogger() (*Logger, error) { - // Initialize the base logger baseLogger, err := InitializeBase("log") if err != nil { return nil, err } - // Create the legacy logger wrapper logger := &Logger{ base: baseLogger, errorLogger: GetErrorLogger(), @@ -49,7 +43,6 @@ func newLogger() (*Logger, error) { return logger, nil } -// Close closes the logger func (l *Logger) Close() error { if l.base != nil { return l.base.Close() @@ -57,7 +50,6 @@ func (l *Logger) Close() error { return nil } -// Legacy methods for backward compatibility func (l *Logger) log(level, format string, v ...interface{}) { if l.base != nil { l.base.LogWithCaller(LogLevel(level), 3, format, v...) @@ -94,13 +86,10 @@ func (l *Logger) Panic(format string) { } } -// Global convenience functions for backward compatibility -// These are now implemented in individual logger files to avoid redeclaration func LegacyInfo(format string, v ...interface{}) { if logger != nil { logger.Info(format, v...) } else { - // Fallback to direct logger if legacy logger not initialized GetInfoLogger().Log(format, v...) } } @@ -109,7 +98,6 @@ func LegacyError(format string, v ...interface{}) { if logger != nil { logger.Error(format, v...) } else { - // Fallback to direct logger if legacy logger not initialized GetErrorLogger().Log(format, v...) } } @@ -118,7 +106,6 @@ func LegacyWarn(format string, v ...interface{}) { if logger != nil { logger.Warn(format, v...) } else { - // Fallback to direct logger if legacy logger not initialized GetWarnLogger().Log(format, v...) } } @@ -127,7 +114,6 @@ func LegacyDebug(format string, v ...interface{}) { if logger != nil { logger.Debug(format, v...) } else { - // Fallback to direct logger if legacy logger not initialized GetDebugLogger().Log(format, v...) } } @@ -136,55 +122,42 @@ func Panic(format string) { if logger != nil { logger.Panic(format) } else { - // Fallback to direct logger if legacy logger not initialized GetErrorLogger().LogFatal(format) } } -// Enhanced logging convenience functions -// These provide direct access to specialized logging functions - -// LogStartup logs application startup information func LogStartup(component string, message string) { GetInfoLogger().LogStartup(component, message) } -// LogShutdown logs application shutdown information func LogShutdown(component string, message string) { GetInfoLogger().LogShutdown(component, message) } -// LogOperation logs general operation information func LogOperation(operation string, details string) { GetInfoLogger().LogOperation(operation, details) } -// LogRequest logs incoming HTTP requests func LogRequest(method string, path string, userAgent string) { GetInfoLogger().LogRequest(method, path, userAgent) } -// LogResponse logs outgoing HTTP responses func LogResponse(method string, path string, statusCode int, duration string) { GetInfoLogger().LogResponse(method, path, statusCode, duration) } -// LogSQL logs SQL queries for debugging func LogSQL(query string, args ...interface{}) { GetDebugLogger().LogSQL(query, args...) } -// LogMemory logs memory usage information func LogMemory() { GetDebugLogger().LogMemory() } -// LogTiming logs timing information for performance debugging func LogTiming(operation string, duration interface{}) { GetDebugLogger().LogTiming(operation, duration) } -// GetLegacyLogger returns the legacy logger instance for backward compatibility func GetLegacyLogger() *Logger { if logger == nil { logger, _ = Initialize() @@ -192,21 +165,17 @@ func GetLegacyLogger() *Logger { return logger } -// InitializeLogging initializes all logging components func InitializeLogging() error { - // Initialize legacy logger for backward compatibility _, err := Initialize() if err != nil { return fmt.Errorf("failed to initialize legacy logger: %v", err) } - // Pre-initialize all logger types to ensure separate log files GetErrorLogger() GetWarnLogger() GetInfoLogger() GetDebugLogger() - // Log successful initialization Info("Logging system initialized successfully") return nil diff --git a/local/utl/logging/warn.go b/local/utl/logging/warn.go index 376d0ea..3044140 100644 --- a/local/utl/logging/warn.go +++ b/local/utl/logging/warn.go @@ -5,12 +5,10 @@ import ( "sync" ) -// WarnLogger handles warn-level logging type WarnLogger struct { base *BaseLogger } -// NewWarnLogger creates a new warn logger instance func NewWarnLogger() *WarnLogger { base, _ := InitializeBase("warn") return &WarnLogger{ @@ -18,14 +16,12 @@ func NewWarnLogger() *WarnLogger { } } -// Log writes a warn-level log entry func (wl *WarnLogger) Log(format string, v ...interface{}) { if wl.base != nil { wl.base.Log(LogLevelWarn, format, v...) } } -// LogWithContext writes a warn-level log entry with additional context func (wl *WarnLogger) LogWithContext(context string, format string, v ...interface{}) { if wl.base != nil { contextualFormat := fmt.Sprintf("[%s] %s", context, format) @@ -33,7 +29,6 @@ func (wl *WarnLogger) LogWithContext(context string, format string, v ...interfa } } -// LogDeprecation logs a deprecation warning func (wl *WarnLogger) LogDeprecation(feature string, alternative string) { if wl.base != nil { if alternative != "" { @@ -44,27 +39,23 @@ func (wl *WarnLogger) LogDeprecation(feature string, alternative string) { } } -// LogConfiguration logs configuration-related warnings func (wl *WarnLogger) LogConfiguration(setting string, message string) { if wl.base != nil { wl.base.Log(LogLevelWarn, "CONFIG WARNING [%s]: %s", setting, message) } } -// LogPerformance logs performance-related warnings func (wl *WarnLogger) LogPerformance(operation string, threshold string, actual string) { if wl.base != nil { wl.base.Log(LogLevelWarn, "PERFORMANCE WARNING [%s]: exceeded threshold %s, actual: %s", operation, threshold, actual) } } -// Global warn logger instance var ( warnLogger *WarnLogger warnOnce sync.Once ) -// GetWarnLogger returns the global warn logger instance func GetWarnLogger() *WarnLogger { warnOnce.Do(func() { warnLogger = NewWarnLogger() @@ -72,27 +63,22 @@ func GetWarnLogger() *WarnLogger { return warnLogger } -// Warn logs a warn-level message using the global warn logger func Warn(format string, v ...interface{}) { GetWarnLogger().Log(format, v...) } -// WarnWithContext logs a warn-level message with context using the global warn logger func WarnWithContext(context string, format string, v ...interface{}) { GetWarnLogger().LogWithContext(context, format, v...) } -// WarnDeprecation logs a deprecation warning using the global warn logger func WarnDeprecation(feature string, alternative string) { GetWarnLogger().LogDeprecation(feature, alternative) } -// WarnConfiguration logs configuration-related warnings using the global warn logger func WarnConfiguration(setting string, message string) { GetWarnLogger().LogConfiguration(setting, message) } -// WarnPerformance logs performance-related warnings using the global warn logger func WarnPerformance(operation string, threshold string, actual string) { GetWarnLogger().LogPerformance(operation, threshold, actual) } diff --git a/local/utl/network/port.go b/local/utl/network/port.go index ea2ce1e..42e0779 100644 --- a/local/utl/network/port.go +++ b/local/utl/network/port.go @@ -6,12 +6,10 @@ import ( "time" ) -// IsPortAvailable checks if a port is available for both TCP and UDP func IsPortAvailable(port int) bool { return IsTCPPortAvailable(port) && IsUDPPortAvailable(port) } -// IsTCPPortAvailable checks if a TCP port is available func IsTCPPortAvailable(port int) bool { addr := fmt.Sprintf(":%d", port) listener, err := net.Listen("tcp", addr) @@ -22,7 +20,6 @@ func IsTCPPortAvailable(port int) bool { return true } -// IsUDPPortAvailable checks if a UDP port is available func IsUDPPortAvailable(port int) bool { conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: port}) if err != nil { @@ -32,7 +29,6 @@ func IsUDPPortAvailable(port int) bool { return true } -// FindAvailablePort finds an available port starting from the given port func FindAvailablePort(startPort int) (int, error) { maxPort := 65535 for port := startPort; port <= maxPort; port++ { @@ -43,14 +39,12 @@ func FindAvailablePort(startPort int) (int, error) { return 0, fmt.Errorf("no available ports found between %d and %d", startPort, maxPort) } -// FindAvailablePortRange finds a range of consecutive available ports func FindAvailablePortRange(startPort, count int) ([]int, error) { maxPort := 65535 ports := make([]int, 0, count) currentPort := startPort for len(ports) < count && currentPort <= maxPort { - // Check if we have enough consecutive ports available available := true for i := 0; i < count-len(ports); i++ { if !IsPortAvailable(currentPort + i) { @@ -74,7 +68,6 @@ func FindAvailablePortRange(startPort, count int) ([]int, error) { return ports, nil } -// WaitForPortAvailable waits for a port to become available with timeout func WaitForPortAvailable(port int, timeout time.Duration) error { deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { @@ -84,4 +77,4 @@ func WaitForPortAvailable(port int, timeout time.Duration) error { time.Sleep(100 * time.Millisecond) } return fmt.Errorf("timeout waiting for port %d to become available", port) -} \ No newline at end of file +} diff --git a/local/utl/password/password.go b/local/utl/password/password.go index ad2424f..82e127e 100644 --- a/local/utl/password/password.go +++ b/local/utl/password/password.go @@ -8,13 +8,10 @@ import ( ) const ( - // MinPasswordLength defines the minimum password length MinPasswordLength = 8 - // BcryptCost defines the cost factor for bcrypt hashing - BcryptCost = 12 + BcryptCost = 12 ) -// HashPassword hashes a plain text password using bcrypt func HashPassword(password string) (string, error) { if len(password) < MinPasswordLength { return "", errors.New("password must be at least 8 characters long") @@ -28,12 +25,10 @@ func HashPassword(password string) (string, error) { return string(hashedBytes), nil } -// VerifyPassword verifies a plain text password against a hashed password func VerifyPassword(hashedPassword, password string) error { return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) } -// ValidatePasswordStrength validates password complexity requirements func ValidatePasswordStrength(password string) error { if len(password) < MinPasswordLength { return errors.New("password must be at least 8 characters long") diff --git a/local/utl/server/server.go b/local/utl/server/server.go index ba5524e..9d82f6f 100644 --- a/local/utl/server/server.go +++ b/local/utl/server/server.go @@ -17,25 +17,23 @@ import ( func Start(di *dig.Container) *fiber.App { app := fiber.New(fiber.Config{ EnablePrintRoutes: true, - ReadTimeout: 20 * time.Minute, // Increased for long-running Steam operations - WriteTimeout: 20 * time.Minute, // Increased for long-running Steam operations - IdleTimeout: 25 * time.Minute, // Increased accordingly - BodyLimit: 10 * 1024 * 1024, // 10MB + ReadTimeout: 20 * time.Minute, + WriteTimeout: 20 * time.Minute, + IdleTimeout: 25 * time.Minute, + BodyLimit: 10 * 1024 * 1024, }) - // Initialize security middleware securityMW := security.NewSecurityMiddleware() - // Add security middleware stack app.Use(securityMW.SecurityHeaders()) app.Use(securityMW.LogSecurityEvents()) - app.Use(securityMW.TimeoutMiddleware(20 * time.Minute)) // Increased for Steam operations - app.Use(securityMW.RequestContextTimeout(20 * time.Minute)) // Increased for Steam operations - app.Use(securityMW.RequestSizeLimit(10 * 1024 * 1024)) // 10MB + app.Use(securityMW.TimeoutMiddleware(20 * time.Minute)) + app.Use(securityMW.RequestContextTimeout(20 * time.Minute)) + app.Use(securityMW.RequestSizeLimit(10 * 1024 * 1024)) app.Use(securityMW.ValidateUserAgent()) app.Use(securityMW.ValidateContentType("application/json", "application/x-www-form-urlencoded", "multipart/form-data")) app.Use(securityMW.InputSanitization()) - app.Use(securityMW.RateLimit(100, 1*time.Minute)) // 100 requests per minute global + app.Use(securityMW.RateLimit(100, 1*time.Minute)) app.Use(helmet.New()) @@ -62,7 +60,7 @@ func Start(di *dig.Container) *fiber.App { port := os.Getenv("PORT") if port == "" { - port = "3000" // Default port + port = "3000" } logging.Info("Starting server on port %s", port) diff --git a/local/utl/tracking/log_tailer.go b/local/utl/tracking/log_tailer.go index fad0304..9530334 100644 --- a/local/utl/tracking/log_tailer.go +++ b/local/utl/tracking/log_tailer.go @@ -7,11 +7,11 @@ import ( ) type LogTailer struct { - filePath string - handleLine func(string) - stopChan chan struct{} - isRunning bool - tracker *PositionTracker + filePath string + handleLine func(string) + stopChan chan struct{} + isRunning bool + tracker *PositionTracker } func NewLogTailer(filePath string, handleLine func(string)) *LogTailer { @@ -30,10 +30,9 @@ func (t *LogTailer) Start() { t.isRunning = true go func() { - // Load last position from tracker pos, err := t.tracker.LoadPosition() if err != nil { - pos = &LogPosition{} // Start from beginning if error + pos = &LogPosition{} } lastSize := pos.LastPosition @@ -43,7 +42,6 @@ func (t *LogTailer) Start() { t.isRunning = false return default: - // Try to open and read the file if file, err := os.Open(t.filePath); err == nil { stat, err := file.Stat() if err != nil { @@ -52,12 +50,10 @@ func (t *LogTailer) Start() { continue } - // If file was truncated, start from beginning if stat.Size() < lastSize { lastSize = 0 } - // Seek to last read position if lastSize > 0 { file.Seek(lastSize, 0) } @@ -66,9 +62,8 @@ func (t *LogTailer) Start() { for scanner.Scan() { line := scanner.Text() t.handleLine(line) - lastSize, _ = file.Seek(0, 1) // Get current position - - // Save position periodically + lastSize, _ = file.Seek(0, 1) + t.tracker.SavePosition(&LogPosition{ LastPosition: lastSize, LastRead: line, @@ -78,7 +73,6 @@ func (t *LogTailer) Start() { file.Close() } - // Wait before next attempt time.Sleep(time.Second) } } @@ -90,4 +84,4 @@ func (t *LogTailer) Stop() { return } close(t.stopChan) -} \ No newline at end of file +} diff --git a/local/utl/tracking/position_tracker.go b/local/utl/tracking/position_tracker.go index 3dc1075..13c8827 100644 --- a/local/utl/tracking/position_tracker.go +++ b/local/utl/tracking/position_tracker.go @@ -7,8 +7,8 @@ import ( ) type LogPosition struct { - LastPosition int64 `json:"last_position"` - LastRead string `json:"last_read"` + LastPosition int64 `json:"last_position"` + LastRead string `json:"last_read"` } type PositionTracker struct { @@ -16,11 +16,10 @@ type PositionTracker struct { } func NewPositionTracker(logPath string) *PositionTracker { - // Create position file in same directory as log file dir := filepath.Dir(logPath) base := filepath.Base(logPath) positionFile := filepath.Join(dir, "."+base+".position") - + return &PositionTracker{ positionFile: positionFile, } @@ -30,7 +29,6 @@ func (t *PositionTracker) LoadPosition() (*LogPosition, error) { data, err := os.ReadFile(t.positionFile) if err != nil { if os.IsNotExist(err) { - // Return empty position if file doesn't exist return &LogPosition{}, nil } return nil, err @@ -51,4 +49,4 @@ func (t *PositionTracker) SavePosition(pos *LogPosition) error { } return os.WriteFile(t.positionFile, data, 0644) -} \ No newline at end of file +} diff --git a/local/utl/tracking/tracking.go b/local/utl/tracking/tracking.go index 1605ae1..7dcd294 100644 --- a/local/utl/tracking/tracking.go +++ b/local/utl/tracking/tracking.go @@ -80,7 +80,7 @@ func TailLogFile(path string, callback func(string)) { file, _ := os.Open(path) defer file.Close() - file.Seek(0, os.SEEK_END) // Start at end of file + file.Seek(0, os.SEEK_END) reader := bufio.NewReader(file) for { @@ -88,7 +88,7 @@ func TailLogFile(path string, callback func(string)) { if err == nil { callback(line) } else { - time.Sleep(500 * time.Millisecond) // wait for new data + time.Sleep(500 * time.Millisecond) } } } diff --git a/local/utl/websocket/websocket.go b/local/utl/websocket/websocket.go new file mode 100644 index 0000000..f43fe81 --- /dev/null +++ b/local/utl/websocket/websocket.go @@ -0,0 +1,169 @@ +package websocket + +import ( + "acc-server-manager/local/model" + "acc-server-manager/local/utl/logging" + "encoding/json" + "sync" + "time" + + "github.com/gofiber/websocket/v2" + "github.com/google/uuid" +) + +type WebSocketConnection struct { + conn *websocket.Conn + serverID *uuid.UUID + userID *uuid.UUID +} + +type WebSocketService struct { + connections sync.Map + mu sync.RWMutex +} + +func NewWebSocketService() *WebSocketService { + return &WebSocketService{} +} + +func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, userID *uuid.UUID) { + wsConn := &WebSocketConnection{ + conn: conn, + userID: userID, + } + ws.connections.Store(connID, wsConn) + logging.Info("WebSocket connection added: %s for user: %v", connID, userID) +} + +func (ws *WebSocketService) RemoveConnection(connID string) { + if conn, exists := ws.connections.LoadAndDelete(connID); exists { + if wsConn, ok := conn.(*WebSocketConnection); ok { + wsConn.conn.Close() + } + } + logging.Info("WebSocket connection removed: %s", connID) +} + +func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) { + if conn, exists := ws.connections.Load(connID); exists { + if wsConn, ok := conn.(*WebSocketConnection); ok { + wsConn.serverID = &serverID + } + } +} + +func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerCreationStep, status model.StepStatus, message string, errorMsg string) { + stepMsg := model.StepMessage{ + Step: step, + Status: status, + Message: message, + Error: errorMsg, + } + + wsMsg := model.WebSocketMessage{ + Type: model.MessageTypeStep, + ServerID: &serverID, + Timestamp: time.Now().Unix(), + Data: stepMsg, + } + + ws.broadcastToServer(serverID, wsMsg) +} + +func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output string, isError bool) { + steamMsg := model.SteamOutputMessage{ + Output: output, + IsError: isError, + } + + wsMsg := model.WebSocketMessage{ + Type: model.MessageTypeSteamOutput, + ServerID: &serverID, + Timestamp: time.Now().Unix(), + Data: steamMsg, + } + + ws.broadcastToServer(serverID, wsMsg) +} + +func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, details string) { + errorMsg := model.ErrorMessage{ + Error: error, + Details: details, + } + + wsMsg := model.WebSocketMessage{ + Type: model.MessageTypeError, + ServerID: &serverID, + Timestamp: time.Now().Unix(), + Data: errorMsg, + } + + ws.broadcastToServer(serverID, wsMsg) +} + +func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool, message string) { + completeMsg := model.CompleteMessage{ + ServerID: serverID, + Success: success, + Message: message, + } + + wsMsg := model.WebSocketMessage{ + Type: model.MessageTypeComplete, + ServerID: &serverID, + Timestamp: time.Now().Unix(), + Data: completeMsg, + } + + ws.broadcastToServer(serverID, wsMsg) +} + +func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model.WebSocketMessage) { + data, err := json.Marshal(message) + if err != nil { + logging.Error("Failed to marshal WebSocket message: %v", err) + return + } + + ws.connections.Range(func(key, value interface{}) bool { + if wsConn, ok := value.(*WebSocketConnection); ok { + if wsConn.serverID != nil && *wsConn.serverID == serverID { + if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil { + logging.Error("Failed to send WebSocket message to connection %s: %v", key, err) + ws.RemoveConnection(key.(string)) + } + } + } + return true + }) +} + +func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebSocketMessage) { + data, err := json.Marshal(message) + if err != nil { + logging.Error("Failed to marshal WebSocket message: %v", err) + return + } + + ws.connections.Range(func(key, value interface{}) bool { + if wsConn, ok := value.(*WebSocketConnection); ok { + if wsConn.userID != nil && *wsConn.userID == userID { + if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil { + logging.Error("Failed to send WebSocket message to connection %s: %v", key, err) + ws.RemoveConnection(key.(string)) + } + } + } + return true + }) +} + +func (ws *WebSocketService) GetActiveConnections() int { + count := 0 + ws.connections.Range(func(key, value interface{}) bool { + count++ + return true + }) + return count +} diff --git a/swagger/docs.go b/swagger/docs.go index 5725dec..2477ed1 100644 --- a/swagger/docs.go +++ b/swagger/docs.go @@ -1,4 +1,3 @@ -// Package swagger Code generated by swaggo/swag. DO NOT EDIT package swagger import "github.com/swaggo/swag" @@ -2239,7 +2238,6 @@ const docTemplate = `{ } }` -// SwaggerInfo holds exported Swagger Info so clients can modify it var SwaggerInfo = &swag.Spec{ Version: "1.0", Host: "acc-api.jurmanovic.com", diff --git a/tests/auth_helper.go b/tests/auth_helper.go index eb9b795..2411d88 100644 --- a/tests/auth_helper.go +++ b/tests/auth_helper.go @@ -10,24 +10,19 @@ import ( "github.com/google/uuid" ) -// GenerateTestToken creates a JWT token for testing purposes func GenerateTestToken() (string, error) { - // Create test user user := &model.User{ ID: uuid.New(), Username: "test_user", RoleID: uuid.New(), } - // Use the environment JWT_SECRET for consistency with middleware testSecret := os.Getenv("JWT_SECRET") if testSecret == "" { - // Fallback to a test secret if env var is not set testSecret = "test-secret-that-is-at-least-32-bytes-long-for-security" } jwtHandler := jwt.NewJWTHandler(testSecret) - // Generate JWT token token, err := jwtHandler.GenerateToken(user.ID.String()) if err != nil { return "", fmt.Errorf("failed to generate test token: %w", err) @@ -36,8 +31,6 @@ func GenerateTestToken() (string, error) { return token, nil } -// MustGenerateTestToken generates a test token and panics if it fails -// This is useful for test setup where failing to generate a token is a fatal error func MustGenerateTestToken() string { token, err := GenerateTestToken() if err != nil { @@ -46,24 +39,19 @@ func MustGenerateTestToken() string { return token } -// GenerateTestTokenWithExpiry creates a JWT token with a specific expiry time func GenerateTestTokenWithExpiry(expiryTime time.Time) (string, error) { - // Use the environment JWT_SECRET for consistency with middleware testSecret := os.Getenv("JWT_SECRET") if testSecret == "" { - // Fallback to a test secret if env var is not set testSecret = "test-secret-that-is-at-least-32-bytes-long-for-security" } jwtHandler := jwt.NewJWTHandler(testSecret) - // Create test user user := &model.User{ ID: uuid.New(), Username: "test_user", RoleID: uuid.New(), } - // Generate JWT token with custom expiry token, err := jwtHandler.GenerateTokenWithExpiry(user, expiryTime) if err != nil { return "", fmt.Errorf("failed to generate test token with expiry: %w", err) @@ -72,8 +60,6 @@ func GenerateTestTokenWithExpiry(expiryTime time.Time) (string, error) { return token, nil } -// AddAuthHeader adds a test auth token to the request headers -// This is a convenience method for tests that need to authenticate requests func AddAuthHeader(headers map[string]string) (map[string]string, error) { token, err := GenerateTestToken() if err != nil { @@ -88,7 +74,6 @@ func AddAuthHeader(headers map[string]string) (map[string]string, error) { return headers, nil } -// MustAddAuthHeader adds a test auth token to the request headers and panics if it fails func MustAddAuthHeader(headers map[string]string) map[string]string { result, err := AddAuthHeader(headers) if err != nil { diff --git a/tests/mocks/auth_middleware_mock.go b/tests/mocks/auth_middleware_mock.go index fbb5841..7edf4a4 100644 --- a/tests/mocks/auth_middleware_mock.go +++ b/tests/mocks/auth_middleware_mock.go @@ -8,26 +8,20 @@ import ( "github.com/google/uuid" ) -// MockAuthMiddleware provides a test implementation of AuthMiddleware -// that can be used as a drop-in replacement for the real AuthMiddleware type MockAuthMiddleware struct{} -// NewMockAuthMiddleware creates a new MockAuthMiddleware func NewMockAuthMiddleware() *MockAuthMiddleware { return &MockAuthMiddleware{} } -// Authenticate is a middleware that allows all requests without authentication for testing func (m *MockAuthMiddleware) Authenticate(ctx *fiber.Ctx) error { - // Set a mock user ID in context mockUserID := uuid.New().String() ctx.Locals("userID", mockUserID) - // Set mock user info mockUserInfo := &middleware.CachedUserInfo{ UserID: mockUserID, Username: "test_user", - RoleName: "Admin", // Admin role to bypass permission checks + RoleName: "Admin", Permissions: map[string]bool{"*": true}, CachedAt: time.Now(), } @@ -38,21 +32,18 @@ func (m *MockAuthMiddleware) Authenticate(ctx *fiber.Ctx) error { return ctx.Next() } -// HasPermission is a middleware that allows all permission checks to pass for testing func (m *MockAuthMiddleware) HasPermission(requiredPermission string) fiber.Handler { return func(ctx *fiber.Ctx) error { return ctx.Next() } } -// AuthRateLimit is a test implementation that allows all requests func (m *MockAuthMiddleware) AuthRateLimit() fiber.Handler { return func(ctx *fiber.Ctx) error { return ctx.Next() } } -// RequireHTTPS is a test implementation that allows all HTTP requests func (m *MockAuthMiddleware) RequireHTTPS() fiber.Handler { return func(ctx *fiber.Ctx) error { return ctx.Next() diff --git a/tests/mocks/repository_mock.go b/tests/mocks/repository_mock.go index 1a6e27c..f29c15f 100644 --- a/tests/mocks/repository_mock.go +++ b/tests/mocks/repository_mock.go @@ -8,7 +8,6 @@ import ( "github.com/google/uuid" ) -// MockConfigRepository provides a mock implementation of ConfigRepository type MockConfigRepository struct { configs map[string]*model.Config shouldFailGet bool @@ -21,7 +20,6 @@ func NewMockConfigRepository() *MockConfigRepository { } } -// UpdateConfig mocks the UpdateConfig method func (m *MockConfigRepository) UpdateConfig(ctx context.Context, config *model.Config) *model.Config { if m.shouldFailUpdate { return nil @@ -36,18 +34,15 @@ func (m *MockConfigRepository) UpdateConfig(ctx context.Context, config *model.C return config } -// SetShouldFailUpdate configures the mock to fail on UpdateConfig calls func (m *MockConfigRepository) SetShouldFailUpdate(shouldFail bool) { m.shouldFailUpdate = shouldFail } -// GetConfig retrieves a config by server ID and config file func (m *MockConfigRepository) GetConfig(serverID uuid.UUID, configFile string) *model.Config { key := serverID.String() + "_" + configFile return m.configs[key] } -// MockServerRepository provides a mock implementation of ServerRepository type MockServerRepository struct { servers map[uuid.UUID]*model.Server shouldFailGet bool @@ -59,7 +54,6 @@ func NewMockServerRepository() *MockServerRepository { } } -// GetByID mocks the GetByID method func (m *MockServerRepository) GetByID(ctx context.Context, id interface{}) (*model.Server, error) { if m.shouldFailGet { return nil, errors.New("server not found") @@ -88,17 +82,14 @@ func (m *MockServerRepository) GetByID(ctx context.Context, id interface{}) (*mo return server, nil } -// AddServer adds a server to the mock repository func (m *MockServerRepository) AddServer(server *model.Server) { m.servers[server.ID] = server } -// SetShouldFailGet configures the mock to fail on GetByID calls func (m *MockServerRepository) SetShouldFailGet(shouldFail bool) { m.shouldFailGet = shouldFail } -// MockServerService provides a mock implementation of ServerService type MockServerService struct { startRuntimeCalled bool startRuntimeServer *model.Server @@ -108,23 +99,19 @@ func NewMockServerService() *MockServerService { return &MockServerService{} } -// StartAccServerRuntime mocks the StartAccServerRuntime method func (m *MockServerService) StartAccServerRuntime(server *model.Server) { m.startRuntimeCalled = true m.startRuntimeServer = server } -// WasStartRuntimeCalled returns whether StartAccServerRuntime was called func (m *MockServerService) WasStartRuntimeCalled() bool { return m.startRuntimeCalled } -// GetStartRuntimeServer returns the server passed to StartAccServerRuntime func (m *MockServerService) GetStartRuntimeServer() *model.Server { return m.startRuntimeServer } -// Reset resets the mock state func (m *MockServerService) Reset() { m.startRuntimeCalled = false m.startRuntimeServer = nil diff --git a/tests/mocks/state_history_mock.go b/tests/mocks/state_history_mock.go index 09e1740..bc811ea 100644 --- a/tests/mocks/state_history_mock.go +++ b/tests/mocks/state_history_mock.go @@ -8,7 +8,6 @@ import ( "github.com/google/uuid" ) -// MockStateHistoryRepository provides a mock implementation of StateHistoryRepository type MockStateHistoryRepository struct { stateHistories []model.StateHistory shouldFailGet bool @@ -21,7 +20,6 @@ func NewMockStateHistoryRepository() *MockStateHistoryRepository { } } -// GetAll mocks the GetAll method func (m *MockStateHistoryRepository) GetAll(ctx context.Context, filter *model.StateHistoryFilter) (*[]model.StateHistory, error) { if m.shouldFailGet { return nil, errors.New("failed to get state history") @@ -37,13 +35,11 @@ func (m *MockStateHistoryRepository) GetAll(ctx context.Context, filter *model.S return &filtered, nil } -// Insert mocks the Insert method func (m *MockStateHistoryRepository) Insert(ctx context.Context, stateHistory *model.StateHistory) error { if m.shouldFailInsert { return errors.New("failed to insert state history") } - // Simulate BeforeCreate hook if stateHistory.ID == uuid.Nil { stateHistory.ID = uuid.New() } @@ -55,7 +51,6 @@ func (m *MockStateHistoryRepository) Insert(ctx context.Context, stateHistory *m return nil } -// GetLastSessionID mocks the GetLastSessionID method func (m *MockStateHistoryRepository) GetLastSessionID(ctx context.Context, serverID uuid.UUID) (uuid.UUID, error) { for i := len(m.stateHistories) - 1; i >= 0; i-- { if m.stateHistories[i].ServerID == serverID { @@ -65,7 +60,6 @@ func (m *MockStateHistoryRepository) GetLastSessionID(ctx context.Context, serve return uuid.Nil, nil } -// Helper methods for filtering func (m *MockStateHistoryRepository) matchesFilter(sh model.StateHistory, filter *model.StateHistoryFilter) bool { if filter == nil { return true @@ -93,7 +87,6 @@ func (m *MockStateHistoryRepository) matchesFilter(sh model.StateHistory, filter return true } -// Helper methods for testing configuration func (m *MockStateHistoryRepository) SetShouldFailGet(shouldFail bool) { m.shouldFailGet = shouldFail } @@ -102,7 +95,6 @@ func (m *MockStateHistoryRepository) SetShouldFailInsert(shouldFail bool) { m.shouldFailInsert = shouldFail } -// AddStateHistory adds a state history entry to the mock repository func (m *MockStateHistoryRepository) AddStateHistory(stateHistory model.StateHistory) { if stateHistory.ID == uuid.Nil { stateHistory.ID = uuid.New() @@ -113,22 +105,18 @@ func (m *MockStateHistoryRepository) AddStateHistory(stateHistory model.StateHis m.stateHistories = append(m.stateHistories, stateHistory) } -// GetCount returns the number of state history entries func (m *MockStateHistoryRepository) GetCount() int { return len(m.stateHistories) } -// Clear removes all state history entries func (m *MockStateHistoryRepository) Clear() { m.stateHistories = make([]model.StateHistory, 0) } -// GetSummaryStats calculates peak players, total sessions, and average players for mock data func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter *model.StateHistoryFilter) (model.StateHistoryStats, error) { var stats model.StateHistoryStats var filteredEntries []model.StateHistory - // Filter entries for _, entry := range m.stateHistories { if m.matchesFilter(entry, filter) { filteredEntries = append(filteredEntries, entry) @@ -139,7 +127,6 @@ func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter return stats, nil } - // Calculate statistics sessionMap := make(map[string]bool) totalPlayers := 0 @@ -159,11 +146,9 @@ func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter return stats, nil } -// GetTotalPlaytime calculates total playtime in minutes for mock data func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *model.StateHistoryFilter) (int, error) { var filteredEntries []model.StateHistory - // Filter entries for _, entry := range m.stateHistories { if m.matchesFilter(entry, filter) { filteredEntries = append(filteredEntries, entry) @@ -174,7 +159,6 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte return 0, nil } - // Group by session and calculate durations sessionMap := make(map[string][]model.StateHistory) for _, entry := range filteredEntries { sessionID := entry.SessionID.String() @@ -184,7 +168,6 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte totalMinutes := 0 for _, sessionEntries := range sessionMap { if len(sessionEntries) > 1 { - // Sort by date (simple approach for mock) minTime := sessionEntries[0].DateCreated maxTime := sessionEntries[0].DateCreated hasPlayers := false @@ -211,26 +194,22 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte return totalMinutes, nil } -// GetPlayerCountOverTime returns downsampled player count data for mock func (m *MockStateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, filter *model.StateHistoryFilter) ([]model.PlayerCountPoint, error) { var points []model.PlayerCountPoint var filteredEntries []model.StateHistory - // Filter entries for _, entry := range m.stateHistories { if m.matchesFilter(entry, filter) { filteredEntries = append(filteredEntries, entry) } } - // Group by hour (simple mock implementation) hourMap := make(map[string][]int) for _, entry := range filteredEntries { hourKey := entry.DateCreated.Format("2006-01-02 15") hourMap[hourKey] = append(hourMap[hourKey], entry.PlayerCount) } - // Calculate averages per hour for hourKey, counts := range hourMap { total := 0 for _, count := range counts { @@ -247,20 +226,17 @@ func (m *MockStateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, return points, nil } -// GetSessionTypes counts sessions by type for mock func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter *model.StateHistoryFilter) ([]model.SessionCount, error) { var sessionTypes []model.SessionCount var filteredEntries []model.StateHistory - // Filter entries for _, entry := range m.stateHistories { if m.matchesFilter(entry, filter) { filteredEntries = append(filteredEntries, entry) } } - // Group by session type - sessionMap := make(map[model.TrackSession]map[string]bool) // session -> sessionID -> bool + sessionMap := make(map[model.TrackSession]map[string]bool) for _, entry := range filteredEntries { if sessionMap[entry.Session] == nil { sessionMap[entry.Session] = make(map[string]bool) @@ -268,7 +244,6 @@ func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter sessionMap[entry.Session][entry.SessionID.String()] = true } - // Count unique sessions per type for sessionType, sessions := range sessionMap { sessionTypes = append(sessionTypes, model.SessionCount{ Name: sessionType, @@ -279,20 +254,17 @@ func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter return sessionTypes, nil } -// GetDailyActivity counts sessions per day for mock func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filter *model.StateHistoryFilter) ([]model.DailyActivity, error) { var dailyActivity []model.DailyActivity var filteredEntries []model.StateHistory - // Filter entries for _, entry := range m.stateHistories { if m.matchesFilter(entry, filter) { filteredEntries = append(filteredEntries, entry) } } - // Group by day - dayMap := make(map[string]map[string]bool) // date -> sessionID -> bool + dayMap := make(map[string]map[string]bool) for _, entry := range filteredEntries { dateKey := entry.DateCreated.Format("2006-01-02") if dayMap[dateKey] == nil { @@ -301,7 +273,6 @@ func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filte dayMap[dateKey][entry.SessionID.String()] = true } - // Count unique sessions per day for date, sessions := range dayMap { dailyActivity = append(dailyActivity, model.DailyActivity{ Date: date, @@ -312,26 +283,22 @@ func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filte return dailyActivity, nil } -// GetRecentSessions retrieves recent sessions for mock func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filter *model.StateHistoryFilter) ([]model.RecentSession, error) { var recentSessions []model.RecentSession var filteredEntries []model.StateHistory - // Filter entries for _, entry := range m.stateHistories { if m.matchesFilter(entry, filter) { filteredEntries = append(filteredEntries, entry) } } - // Group by session sessionMap := make(map[string][]model.StateHistory) for _, entry := range filteredEntries { sessionID := entry.SessionID.String() sessionMap[sessionID] = append(sessionMap[sessionID], entry) } - // Create recent sessions (limit to 10) count := 0 for _, entries := range sessionMap { if count >= 10 { @@ -339,7 +306,6 @@ func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filt } if len(entries) > 0 { - // Find min/max dates and max players minDate := entries[0].DateCreated maxDate := entries[0].DateCreated maxPlayers := 0 @@ -356,7 +322,6 @@ func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filt } } - // Only include sessions with players if maxPlayers > 0 { duration := int(maxDate.Sub(minDate).Minutes()) recentSessions = append(recentSessions, model.RecentSession{ diff --git a/tests/test_helper.go b/tests/test_helper.go index 81e429e..b6d0132 100644 --- a/tests/test_helper.go +++ b/tests/test_helper.go @@ -24,14 +24,12 @@ import ( "gorm.io/gorm/logger" ) -// TestHelper provides utilities for testing type TestHelper struct { DB *gorm.DB TempDir string TestData *TestData } -// TestData contains common test data structures type TestData struct { ServerID uuid.UUID Server *model.Server @@ -39,37 +37,29 @@ type TestData struct { SampleConfig *model.Configuration } -// SetTestEnv sets the required environment variables for tests func SetTestEnv() { - // Set required environment variables for testing os.Setenv("APP_SECRET", "test-secret-key-for-testing-123456") os.Setenv("APP_SECRET_CODE", "test-code-for-testing-123456789012") os.Setenv("ENCRYPTION_KEY", "12345678901234567890123456789012") os.Setenv("JWT_SECRET", "test-jwt-secret-key-for-testing-123456789012345678901234567890") os.Setenv("ACCESS_KEY", "test-access-key-for-testing") - // Set test-specific environment variables - os.Setenv("TESTING_ENV", "true") // Used to bypass + os.Setenv("TESTING_ENV", "true") configs.Init() } -// NewTestHelper creates a new test helper with in-memory database func NewTestHelper(t *testing.T) *TestHelper { - // Set required environment variables SetTestEnv() - // Create temporary directory for test files tempDir := t.TempDir() - // Create in-memory SQLite database for testing db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ - Logger: logger.Default.LogMode(logger.Silent), // Suppress SQL logs in tests + Logger: logger.Default.LogMode(logger.Silent), }) if err != nil { t.Fatalf("Failed to connect to test database: %v", err) } - // Auto-migrate the schema err = db.AutoMigrate( &model.Server{}, &model.Config{}, @@ -79,7 +69,6 @@ func NewTestHelper(t *testing.T) *TestHelper { &model.StateHistory{}, ) - // Explicitly ensure tables exist with correct structure if !db.Migrator().HasTable(&model.StateHistory{}) { err = db.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -90,7 +79,6 @@ func NewTestHelper(t *testing.T) *TestHelper { t.Fatalf("Failed to migrate test database: %v", err) } - // Create test data testData := createTestData(t, tempDir) return &TestHelper{ @@ -100,11 +88,9 @@ func NewTestHelper(t *testing.T) *TestHelper { } } -// createTestData creates common test data structures func createTestData(t *testing.T, tempDir string) *TestData { serverID := uuid.New() - // Create sample server server := &model.Server{ ID: serverID, Name: "Test Server", @@ -115,13 +101,11 @@ func createTestData(t *testing.T, tempDir string) *TestData { FromSteamCMD: false, } - // Create server directory serverConfigDir := filepath.Join(tempDir, "server", "cfg") if err := os.MkdirAll(serverConfigDir, 0755); err != nil { t.Fatalf("Failed to create server config directory: %v", err) } - // Sample configuration files content configFiles := map[string]string{ "configuration.json": `{ "udpPort": "9231", @@ -212,7 +196,6 @@ func createTestData(t *testing.T, tempDir string) *TestData { }`, } - // Sample configuration struct sampleConfig := &model.Configuration{ UdpPort: model.IntString(9231), TcpPort: model.IntString(9232), @@ -230,14 +213,12 @@ func createTestData(t *testing.T, tempDir string) *TestData { } } -// CreateTestConfigFiles creates actual config files in the test directory func (th *TestHelper) CreateTestConfigFiles() error { serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg") for filename, content := range th.TestData.ConfigFiles { filePath := filepath.Join(serverConfigDir, filename) - // Encode content to UTF-16 LE BOM format as expected by the application utf16Content, err := EncodeUTF16LEBOM([]byte(content)) if err != nil { return err @@ -251,7 +232,6 @@ func (th *TestHelper) CreateTestConfigFiles() error { return nil } -// CreateMalformedConfigFile creates a config file with invalid JSON func (th *TestHelper) CreateMalformedConfigFile(filename string) error { serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg") filePath := filepath.Join(serverConfigDir, filename) @@ -259,67 +239,51 @@ func (th *TestHelper) CreateMalformedConfigFile(filename string) error { malformedJSON := `{ "udpPort": "9231", "tcpPort": "9232" - "maxConnections": "30" // Missing comma - invalid JSON + "maxConnections": "30" }` return os.WriteFile(filePath, []byte(malformedJSON), 0644) } -// RemoveConfigFile removes a config file to simulate missing file scenarios func (th *TestHelper) RemoveConfigFile(filename string) error { serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg") filePath := filepath.Join(serverConfigDir, filename) return os.Remove(filePath) } -// InsertTestServer inserts the test server into the database func (th *TestHelper) InsertTestServer() error { return th.DB.Create(th.TestData.Server).Error } -// CreateContext creates a test context func (th *TestHelper) CreateContext() context.Context { return context.Background() } -// CreateFiberCtx creates a fiber.Ctx for testing func (th *TestHelper) CreateFiberCtx() *fiber.Ctx { - // Create app and request for fiber context app := fiber.New() - // Create a dummy request that doesn't depend on external http objects req := httptest.NewRequest("GET", "/", nil) - // Create the fiber context from real request/response ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) - // Store the original request for release later ctx.Locals("original-request", req) - // Return the context which can be safely used in tests return ctx } -// ReleaseFiberCtx properly releases a fiber context created with CreateFiberCtx func (th *TestHelper) ReleaseFiberCtx(app *fiber.App, ctx *fiber.Ctx) { if app != nil && ctx != nil { app.ReleaseCtx(ctx) } } -// Cleanup performs cleanup operations after tests func (th *TestHelper) Cleanup() { - // Close database connection if sqlDB, err := th.DB.DB(); err == nil { sqlDB.Close() } - // Temporary directory is automatically cleaned up by t.TempDir() } -// LoadTestEnvFile loads environment variables from a .env file for testing func LoadTestEnvFile() error { - // Try to load from .env file return godotenv.Load() } -// AssertNoError is a helper function to check for errors in tests func AssertNoError(t *testing.T, err error) { t.Helper() if err != nil { @@ -327,7 +291,6 @@ func AssertNoError(t *testing.T, err error) { } } -// AssertError is a helper function to check for expected errors func AssertError(t *testing.T, err error, expectedMsg string) { t.Helper() if err == nil { @@ -338,7 +301,6 @@ func AssertError(t *testing.T, err error, expectedMsg string) { } } -// AssertEqual checks if two values are equal func AssertEqual(t *testing.T, expected, actual interface{}) { t.Helper() if expected != actual { @@ -346,7 +308,6 @@ func AssertEqual(t *testing.T, expected, actual interface{}) { } } -// AssertNotNil checks if a value is not nil func AssertNotNil(t *testing.T, value interface{}) { t.Helper() if value == nil { @@ -354,12 +315,9 @@ func AssertNotNil(t *testing.T, value interface{}) { } } -// AssertNil checks if a value is nil func AssertNil(t *testing.T, value interface{}) { t.Helper() if value != nil { - // Special handling for interface values that contain nil but aren't nil themselves - // For example, (*jwt.Claims)(nil) is not equal to nil, but it contains nil switch v := value.(type) { case *interface{}: if v == nil || *v == nil { @@ -374,13 +332,11 @@ func AssertNil(t *testing.T, value interface{}) { } } -// EncodeUTF16LEBOM encodes UTF-8 bytes to UTF-16 LE BOM format func EncodeUTF16LEBOM(input []byte) ([]byte, error) { encoder := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM) return transformBytes(encoder.NewEncoder(), input) } -// transformBytes applies a transform to input bytes func transformBytes(t transform.Transformer, input []byte) ([]byte, error) { var buf bytes.Buffer w := transform.NewWriter(&buf, t) @@ -396,7 +352,6 @@ func transformBytes(t transform.Transformer, input []byte) ([]byte, error) { return buf.Bytes(), nil } -// ErrorForTesting creates an error for testing purposes func ErrorForTesting(message string) error { return errors.New(message) } diff --git a/tests/testdata/state_history_data.go b/tests/testdata/state_history_data.go index 7b6962c..cc1bbd9 100644 --- a/tests/testdata/state_history_data.go +++ b/tests/testdata/state_history_data.go @@ -7,13 +7,11 @@ import ( "github.com/google/uuid" ) -// StateHistoryTestData provides simple test data generators type StateHistoryTestData struct { ServerID uuid.UUID BaseTime time.Time } -// NewStateHistoryTestData creates a new test data generator func NewStateHistoryTestData(serverID uuid.UUID) *StateHistoryTestData { return &StateHistoryTestData{ ServerID: serverID, @@ -21,7 +19,6 @@ func NewStateHistoryTestData(serverID uuid.UUID) *StateHistoryTestData { } } -// CreateStateHistory creates a basic state history entry func (td *StateHistoryTestData) CreateStateHistory(session model.TrackSession, track string, playerCount int, sessionID uuid.UUID) model.StateHistory { return model.StateHistory{ ID: uuid.New(), @@ -36,7 +33,6 @@ func (td *StateHistoryTestData) CreateStateHistory(session model.TrackSession, t } } -// CreateMultipleEntries creates multiple state history entries for the same session func (td *StateHistoryTestData) CreateMultipleEntries(session model.TrackSession, track string, playerCounts []int) []model.StateHistory { sessionID := uuid.New() var entries []model.StateHistory @@ -59,7 +55,6 @@ func (td *StateHistoryTestData) CreateMultipleEntries(session model.TrackSession return entries } -// CreateBasicFilter creates a basic filter for testing func CreateBasicFilter(serverID string) *model.StateHistoryFilter { return &model.StateHistoryFilter{ ServerBasedFilter: model.ServerBasedFilter{ @@ -68,7 +63,6 @@ func CreateBasicFilter(serverID string) *model.StateHistoryFilter { } } -// CreateFilterWithSession creates a filter with session type func CreateFilterWithSession(serverID string, session model.TrackSession) *model.StateHistoryFilter { return &model.StateHistoryFilter{ ServerBasedFilter: model.ServerBasedFilter{ @@ -78,7 +72,6 @@ func CreateFilterWithSession(serverID string, session model.TrackSession) *model } } -// LogLines contains sample ACC server log lines for testing var SampleLogLines = []string{ "[2024-01-15 14:30:25.123] Session changed: NONE -> PRACTICE", "[2024-01-15 14:30:30.456] 1 client(s) online", @@ -95,7 +88,6 @@ var SampleLogLines = []string{ "[2024-01-15 15:00:05.123] Session changed: RACE -> NONE", } -// ExpectedSessionChanges represents the expected session changes from parsing the sample log lines var ExpectedSessionChanges = []struct { From model.TrackSession To model.TrackSession @@ -106,5 +98,4 @@ var ExpectedSessionChanges = []struct { {model.SessionRace, model.SessionUnknown}, } -// ExpectedPlayerCounts represents the expected player counts from parsing the sample log lines var ExpectedPlayerCounts = []int{1, 3, 5, 8, 12, 15, 14, 0} diff --git a/tests/unit/controller/controller_simple_test.go b/tests/unit/controller/controller_simple_test.go index c75aec5..e01581c 100644 --- a/tests/unit/controller/controller_simple_test.go +++ b/tests/unit/controller/controller_simple_test.go @@ -14,12 +14,10 @@ import ( ) func TestController_JSONParsing_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test basic JSON parsing functionality app := fiber.New() app.Post("/test", func(c *fiber.Ctx) error { @@ -30,7 +28,6 @@ func TestController_JSONParsing_Success(t *testing.T) { return c.JSON(data) }) - // Prepare test data testData := map[string]interface{}{ "name": "test", "value": 123, @@ -38,34 +35,28 @@ func TestController_JSONParsing_Success(t *testing.T) { bodyBytes, err := json.Marshal(testData) tests.AssertNoError(t, err) - // Create request req := httptest.NewRequest("POST", "/test", bytes.NewReader(bodyBytes)) req.Header.Set("Content-Type", "application/json") - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) - // Parse response var response map[string]interface{} body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) err = json.Unmarshal(body, &response) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, "test", response["name"]) - tests.AssertEqual(t, float64(123), response["value"]) // JSON numbers are float64 + tests.AssertEqual(t, float64(123), response["value"]) } func TestController_JSONParsing_InvalidJSON(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test handling of invalid JSON app := fiber.New() app.Post("/test", func(c *fiber.Ctx) error { @@ -76,39 +67,31 @@ func TestController_JSONParsing_InvalidJSON(t *testing.T) { return c.JSON(data) }) - // Create request with invalid JSON req := httptest.NewRequest("POST", "/test", bytes.NewReader([]byte("invalid json"))) req.Header.Set("Content-Type", "application/json") - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 400, resp.StatusCode) - // Parse error response var response map[string]interface{} body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) err = json.Unmarshal(body, &response) tests.AssertNoError(t, err) - - // Verify error response tests.AssertEqual(t, "Invalid JSON", response["error"]) } func TestController_UUIDValidation_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test UUID parameter validation app := fiber.New() app.Get("/test/:id", func(c *fiber.Ctx) error { id := c.Params("id") - // Validate UUID if _, err := uuid.Parse(id); err != nil { return c.Status(400).JSON(fiber.Map{"error": "Invalid UUID"}) } @@ -116,40 +99,33 @@ func TestController_UUIDValidation_Success(t *testing.T) { return c.JSON(fiber.Map{"id": id, "valid": true}) }) - // Create request with valid UUID validUUID := uuid.New().String() req := httptest.NewRequest("GET", "/test/"+validUUID, nil) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) - // Parse response var response map[string]interface{} body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) err = json.Unmarshal(body, &response) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, validUUID, response["id"]) tests.AssertEqual(t, true, response["valid"]) } func TestController_UUIDValidation_InvalidUUID(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test handling of invalid UUID app := fiber.New() app.Get("/test/:id", func(c *fiber.Ctx) error { id := c.Params("id") - // Validate UUID if _, err := uuid.Parse(id); err != nil { return c.Status(400).JSON(fiber.Map{"error": "Invalid UUID"}) } @@ -157,32 +133,26 @@ func TestController_UUIDValidation_InvalidUUID(t *testing.T) { return c.JSON(fiber.Map{"id": id, "valid": true}) }) - // Create request with invalid UUID req := httptest.NewRequest("GET", "/test/invalid-uuid", nil) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 400, resp.StatusCode) - // Parse error response var response map[string]interface{} body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) err = json.Unmarshal(body, &response) tests.AssertNoError(t, err) - // Verify error response tests.AssertEqual(t, "Invalid UUID", response["error"]) } func TestController_QueryParameters_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test query parameter handling app := fiber.New() app.Get("/test", func(c *fiber.Ctx) error { @@ -197,34 +167,28 @@ func TestController_QueryParameters_Success(t *testing.T) { }) }) - // Create request with query parameters req := httptest.NewRequest("GET", "/test?restart=true&override=false&format=xml", nil) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) - // Parse response var response map[string]interface{} body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) err = json.Unmarshal(body, &response) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, true, response["restart"]) tests.AssertEqual(t, false, response["override"]) tests.AssertEqual(t, "xml", response["format"]) } func TestController_HTTPMethods_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test different HTTP methods app := fiber.New() var getCalled, postCalled, putCalled, deleteCalled bool @@ -249,28 +213,24 @@ func TestController_HTTPMethods_Success(t *testing.T) { return c.JSON(fiber.Map{"method": "DELETE"}) }) - // Test GET req := httptest.NewRequest("GET", "/test", nil) resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) tests.AssertEqual(t, true, getCalled) - // Test POST req = httptest.NewRequest("POST", "/test", nil) resp, err = app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) tests.AssertEqual(t, true, postCalled) - // Test PUT req = httptest.NewRequest("PUT", "/test", nil) resp, err = app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) tests.AssertEqual(t, true, putCalled) - // Test DELETE req = httptest.NewRequest("DELETE", "/test", nil) resp, err = app.Test(req) tests.AssertNoError(t, err) @@ -279,12 +239,10 @@ func TestController_HTTPMethods_Success(t *testing.T) { } func TestController_ErrorHandling_StatusCodes(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test different error status codes app := fiber.New() app.Get("/400", func(c *fiber.Ctx) error { @@ -307,7 +265,6 @@ func TestController_ErrorHandling_StatusCodes(t *testing.T) { return c.Status(500).JSON(fiber.Map{"error": "Internal Server Error"}) }) - // Test different status codes testCases := []struct { path string code int @@ -328,12 +285,10 @@ func TestController_ErrorHandling_StatusCodes(t *testing.T) { } func TestController_ConfigurationModel_JSONSerialization(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test Configuration model JSON serialization app := fiber.New() app.Get("/config", func(c *fiber.Ctx) error { @@ -348,22 +303,18 @@ func TestController_ConfigurationModel_JSONSerialization(t *testing.T) { return c.JSON(config) }) - // Create request req := httptest.NewRequest("GET", "/config", nil) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) - // Parse response var response model.Configuration body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) err = json.Unmarshal(body, &response) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, model.IntString(9231), response.UdpPort) tests.AssertEqual(t, model.IntString(9232), response.TcpPort) tests.AssertEqual(t, model.IntString(30), response.MaxConnections) @@ -373,73 +324,61 @@ func TestController_ConfigurationModel_JSONSerialization(t *testing.T) { } func TestController_UserModel_JSONSerialization(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test User model JSON serialization (password should be hidden) app := fiber.New() app.Get("/user", func(c *fiber.Ctx) error { user := &model.User{ ID: uuid.New(), Username: "testuser", - Password: "secret-password", // Should not appear in JSON + Password: "secret-password", RoleID: uuid.New(), } return c.JSON(user) }) - // Create request req := httptest.NewRequest("GET", "/user", nil) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) - // Parse response as raw JSON to check password is excluded body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) - // Verify password field is not in JSON if bytes.Contains(body, []byte("password")) || bytes.Contains(body, []byte("secret-password")) { t.Fatal("Password should not be included in JSON response") } - // Verify other fields are present if !bytes.Contains(body, []byte("username")) || !bytes.Contains(body, []byte("testuser")) { t.Fatal("Username should be included in JSON response") } } func TestController_MiddlewareChaining_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Test middleware chaining app := fiber.New() var middleware1Called, middleware2Called, handlerCalled bool - // Middleware 1 middleware1 := func(c *fiber.Ctx) error { middleware1Called = true c.Locals("middleware1", "executed") return c.Next() } - // Middleware 2 middleware2 := func(c *fiber.Ctx) error { middleware2Called = true c.Locals("middleware2", "executed") return c.Next() } - // Handler handler := func(c *fiber.Ctx) error { handlerCalled = true return c.JSON(fiber.Map{ @@ -451,27 +390,22 @@ func TestController_MiddlewareChaining_Success(t *testing.T) { app.Get("/test", middleware1, middleware2, handler) - // Create request req := httptest.NewRequest("GET", "/test", nil) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, 200, resp.StatusCode) - // Verify all were called tests.AssertEqual(t, true, middleware1Called) tests.AssertEqual(t, true, middleware2Called) tests.AssertEqual(t, true, handlerCalled) - // Parse response var response map[string]interface{} body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) err = json.Unmarshal(body, &response) tests.AssertNoError(t, err) - // Verify middleware values were passed tests.AssertEqual(t, "executed", response["middleware1"]) tests.AssertEqual(t, "executed", response["middleware2"]) tests.AssertEqual(t, "executed", response["handler"]) diff --git a/tests/unit/controller/helper.go b/tests/unit/controller/helper.go index 0b501e0..25c25ae 100644 --- a/tests/unit/controller/helper.go +++ b/tests/unit/controller/helper.go @@ -11,27 +11,20 @@ import ( "github.com/gofiber/fiber/v2" ) -// MockMiddleware simulates authentication for testing purposes type MockMiddleware struct{} -// GetTestAuthMiddleware returns a mock auth middleware that can be used in place of the real one -// This works because we're adding real authentication tokens to requests func GetTestAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache) *middleware.AuthMiddleware { - // Use environment JWT secrets for consistency with token generation jwtSecret := os.Getenv("JWT_SECRET") if jwtSecret == "" { jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security" } - + jwtHandler := jwt.NewJWTHandler(jwtSecret) - openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret) // Use same secret for test consistency - - // Cast our mock to the real type for testing - // This is a type-unsafe cast but works for testing because we're using real JWT tokens + openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret) + return middleware.NewAuthMiddleware(ms, cache, jwtHandler, openJWTHandler) } -// AddAuthToRequest adds a valid authentication token to a test request func AddAuthToRequest(req *fiber.Ctx) { token := tests.MustGenerateTestToken() req.Request().Header.Set("Authorization", "Bearer "+token) diff --git a/tests/unit/controller/state_history_controller_test.go b/tests/unit/controller/state_history_controller_test.go index a7e0d0d..748608f 100644 --- a/tests/unit/controller/state_history_controller_test.go +++ b/tests/unit/controller/state_history_controller_test.go @@ -23,13 +23,11 @@ import ( ) func TestStateHistoryController_GetAll_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // No need for DisableAuthentication, we'll use real auth tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -44,33 +42,26 @@ func TestStateHistoryController_GetAll_Success(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Insert test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) err := repo.Insert(helper.CreateContext(), &history) tests.AssertNoError(t, err) - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Create request with authentication req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, http.StatusOK, resp.StatusCode) - // Parse response body body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) @@ -83,13 +74,11 @@ func TestStateHistoryController_GetAll_Success(t *testing.T) { } func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -104,7 +93,6 @@ func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Insert test data with different sessions testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) @@ -115,27 +103,21 @@ func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) { err = repo.Insert(helper.CreateContext(), &raceHistory) tests.AssertNoError(t, err) - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Create request with session filter and authentication req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s&session=R", helper.TestData.ServerID.String()), nil) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, http.StatusOK, resp.StatusCode) - // Parse response body body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) @@ -148,13 +130,11 @@ func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) { } func TestStateHistoryController_GetAll_EmptyResult(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -169,38 +149,30 @@ func TestStateHistoryController_GetAll_EmptyResult(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Create request with no data and authentication req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) - // Verify empty response tests.AssertEqual(t, http.StatusOK, resp.StatusCode) } func TestStateHistoryController_GetStatistics_Success(t *testing.T) { - // Skip this test as it requires more complex setup t.Skip("Skipping test due to UUID validation issues") - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -215,10 +187,8 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Insert test data with multiple entries for statistics testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) - // Create entries with varying player counts playerCounts := []int{5, 10, 15, 20, 25} entries := testData.CreateMultipleEntries(model.SessionRace, "spa", playerCounts) @@ -227,33 +197,26 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) { tests.AssertNoError(t, err) } - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Create request with valid serverID UUID validServerID := helper.TestData.ServerID.String() if validServerID == "" { - validServerID = uuid.New().String() // Generate a new valid UUID if needed + validServerID = uuid.New().String() } req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil) req.Header.Set("Content-Type", "application/json") - // Add Authorization header for testing req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, http.StatusOK, resp.StatusCode) - // Parse response body body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) @@ -261,7 +224,6 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) { err = json.Unmarshal(body, &stats) tests.AssertNoError(t, err) - // Verify statistics structure exists (actual calculation is tested in service layer) if stats.PeakPlayers < 0 { t.Error("Expected non-negative peak players") } @@ -274,16 +236,13 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) { } func TestStateHistoryController_GetStatistics_NoData(t *testing.T) { - // Skip this test as it requires more complex setup t.Skip("Skipping test due to UUID validation issues") - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -298,33 +257,26 @@ func TestStateHistoryController_GetStatistics_NoData(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Create request with valid serverID UUID validServerID := helper.TestData.ServerID.String() if validServerID == "" { - validServerID = uuid.New().String() // Generate a new valid UUID if needed + validServerID = uuid.New().String() } req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil) req.Header.Set("Content-Type", "application/json") - // Add Authorization header for testing req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) - // Verify response tests.AssertEqual(t, http.StatusOK, resp.StatusCode) - // Parse response body body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) @@ -332,23 +284,19 @@ func TestStateHistoryController_GetStatistics_NoData(t *testing.T) { err = json.Unmarshal(body, &stats) tests.AssertNoError(t, err) - // Verify empty statistics tests.AssertEqual(t, 0, stats.PeakPlayers) tests.AssertEqual(t, 0.0, stats.AveragePlayers) tests.AssertEqual(t, 0, stats.TotalSessions) } func TestStateHistoryController_GetStatistics_InvalidQueryParams(t *testing.T) { - // Skip this test as it requires more complex setup t.Skip("Skipping test due to UUID validation issues") - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -363,42 +311,34 @@ func TestStateHistoryController_GetStatistics_InvalidQueryParams(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Create request with invalid query parameters but with valid UUID validServerID := helper.TestData.ServerID.String() if validServerID == "" { - validServerID = uuid.New().String() // Generate a new valid UUID if needed + validServerID = uuid.New().String() } req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s&min_players=invalid", validServerID), nil) req.Header.Set("Content-Type", "application/json") - // Add Authorization header for testing req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) - // Execute request resp, err := app.Test(req) tests.AssertNoError(t, err) - // Verify error response tests.AssertEqual(t, http.StatusBadRequest, resp.StatusCode) } func TestStateHistoryController_HTTPMethods(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -413,36 +353,30 @@ func TestStateHistoryController_HTTPMethods(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Test that only GET method is allowed for GetAll req := httptest.NewRequest("POST", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) resp, err := app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode) - // Test that only GET method is allowed for GetStatistics req = httptest.NewRequest("POST", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) resp, err = app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode) - // Test that PUT method is not allowed req = httptest.NewRequest("PUT", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) resp, err = app.Test(req) tests.AssertNoError(t, err) tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode) - // Test that DELETE method is not allowed req = httptest.NewRequest("DELETE", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) resp, err = app.Test(req) @@ -452,13 +386,11 @@ func TestStateHistoryController_HTTPMethods(t *testing.T) { func TestStateHistoryController_ContentType(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -473,43 +405,36 @@ func TestStateHistoryController_ContentType(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Insert test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) err := repo.Insert(helper.CreateContext(), &history) tests.AssertNoError(t, err) - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Test GetAll endpoint with authentication req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) resp, err := app.Test(req) tests.AssertNoError(t, err) - // Verify content type is JSON contentType := resp.Header.Get("Content-Type") if contentType != "application/json" { t.Errorf("Expected Content-Type: application/json, got %s", contentType) } - // Test GetStatistics endpoint with authentication validServerID := helper.TestData.ServerID.String() if validServerID == "" { - validServerID = uuid.New().String() // Generate a new valid UUID if needed + validServerID = uuid.New().String() } req = httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil) req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) resp, err = app.Test(req) tests.AssertNoError(t, err) - // Verify content type is JSON contentType = resp.Header.Get("Content-Type") if contentType != "application/json" { t.Errorf("Expected Content-Type: application/json, got %s", contentType) @@ -517,16 +442,13 @@ func TestStateHistoryController_ContentType(t *testing.T) { } func TestStateHistoryController_ResponseStructure(t *testing.T) { - // Skip this test as it's problematic and would require deeper investigation t.Skip("Skipping test due to response structure issues that need further investigation") - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() app := fiber.New() - // Using real JWT auth with tokens repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) @@ -541,7 +463,6 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) { inMemCache := cache.NewInMemoryCache() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -549,21 +470,17 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) { } } - // Insert test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) err := repo.Insert(helper.CreateContext(), &history) tests.AssertNoError(t, err) - // Setup routes routeGroups := &common.RouteGroups{ StateHistory: app.Group("/api/v1/state-history"), } - // Use a test auth middleware that works with the DisableAuthentication controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache)) - // Test GetAll response structure with authentication req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil) req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken()) resp, err := app.Test(req) @@ -572,24 +489,19 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) { body, err := io.ReadAll(resp.Body) tests.AssertNoError(t, err) - // Log the actual response for debugging t.Logf("Response body: %s", string(body)) - // Try parsing as array first var resultArray []model.StateHistory err = json.Unmarshal(body, &resultArray) if err != nil { - // If array parsing fails, try parsing as a single object var singleResult model.StateHistory err = json.Unmarshal(body, &singleResult) if err != nil { t.Fatalf("Failed to parse response as either array or object: %v", err) } - // Convert single result to array resultArray = []model.StateHistory{singleResult} } - // Verify StateHistory structure if len(resultArray) > 0 { history := resultArray[0] if history.ID == uuid.Nil { diff --git a/tests/unit/repository/state_history_repository_test.go b/tests/unit/repository/state_history_repository_test.go index d4f1ae5..dc7bedb 100644 --- a/tests/unit/repository/state_history_repository_test.go +++ b/tests/unit/repository/state_history_repository_test.go @@ -12,12 +12,10 @@ import ( ) func TestStateHistoryRepository_Insert_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -28,15 +26,12 @@ func TestStateHistoryRepository_Insert_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Create test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) - // Test Insert err := repo.Insert(ctx, &history) tests.AssertNoError(t, err) - // Verify ID was generated tests.AssertNotNil(t, history.ID) if history.ID == uuid.Nil { t.Error("Expected non-nil ID after insert") @@ -44,12 +39,10 @@ func TestStateHistoryRepository_Insert_Success(t *testing.T) { } func TestStateHistoryRepository_GetAll_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -60,10 +53,8 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Create test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) - // Insert multiple entries playerCounts := []int{0, 5, 10, 15, 10, 5, 0} entries := testData.CreateMultipleEntries(model.SessionPractice, "spa", playerCounts) @@ -72,7 +63,6 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) { tests.AssertNoError(t, err) } - // Test GetAll filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String()) result, err := repo.GetAll(ctx, filter) @@ -82,12 +72,10 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) { } func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -98,19 +86,16 @@ func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Create test data with different sessions testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) raceHistory := testData.CreateStateHistory(model.SessionRace, "spa", 15, uuid.New()) - // Insert both err := repo.Insert(ctx, &practiceHistory) tests.AssertNoError(t, err) err = repo.Insert(ctx, &raceHistory) tests.AssertNoError(t, err) - // Test GetAll with session filter filter := testdata.CreateFilterWithSession(helper.TestData.ServerID.String(), model.SessionRace) result, err := repo.GetAll(ctx, filter) @@ -122,12 +107,10 @@ func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) { } func TestStateHistoryRepository_GetLastSessionID_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -138,44 +121,34 @@ func TestStateHistoryRepository_GetLastSessionID_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Create test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) - // Insert multiple entries with different session IDs sessionID1 := uuid.New() sessionID2 := uuid.New() history1 := testData.CreateStateHistory(model.SessionPractice, "spa", 5, sessionID1) history2 := testData.CreateStateHistory(model.SessionRace, "spa", 10, sessionID2) - // Insert with a small delay to ensure ordering err := repo.Insert(ctx, &history1) tests.AssertNoError(t, err) - time.Sleep(1 * time.Millisecond) // Ensure different timestamps + time.Sleep(1 * time.Millisecond) err = repo.Insert(ctx, &history2) tests.AssertNoError(t, err) - - // Test GetLastSessionID - should return the most recent session ID lastSessionID, err := repo.GetLastSessionID(ctx, helper.TestData.ServerID) tests.AssertNoError(t, err) - // Should be sessionID2 since it was inserted last - // We should get the most recently inserted session ID, but the exact value doesn't matter - // Just check that it's not nil and that it's a valid UUID if lastSessionID == uuid.Nil { t.Fatal("Expected non-nil UUID for last session ID") } } func TestStateHistoryRepository_GetLastSessionID_NoData(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -186,19 +159,16 @@ func TestStateHistoryRepository_GetLastSessionID_NoData(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Test GetLastSessionID with no data lastSessionID, err := repo.GetLastSessionID(ctx, helper.TestData.ServerID) tests.AssertNoError(t, err) tests.AssertEqual(t, uuid.Nil, lastSessionID) } func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -209,14 +179,11 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Create test data with varying player counts testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) - // Create entries with different sessions and player counts sessionID1 := uuid.New() sessionID2 := uuid.New() - // Practice session: 5, 10, 15 players practiceEntries := testData.CreateMultipleEntries(model.SessionPractice, "spa", []int{5, 10, 15}) for i := range practiceEntries { practiceEntries[i].SessionID = sessionID1 @@ -224,7 +191,6 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) { tests.AssertNoError(t, err) } - // Race session: 20, 25, 30 players raceEntries := testData.CreateMultipleEntries(model.SessionRace, "spa", []int{20, 25, 30}) for i := range raceEntries { raceEntries[i].SessionID = sessionID2 @@ -232,17 +198,14 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) { tests.AssertNoError(t, err) } - // Test GetSummaryStats filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String()) stats, err := repo.GetSummaryStats(ctx, filter) tests.AssertNoError(t, err) - // Verify stats are calculated correctly - tests.AssertEqual(t, 30, stats.PeakPlayers) // Maximum player count - tests.AssertEqual(t, 2, stats.TotalSessions) // Two unique sessions + tests.AssertEqual(t, 30, stats.PeakPlayers) + tests.AssertEqual(t, 2, stats.TotalSessions) - // Average should be (5+10+15+20+25+30)/6 = 17.5 expectedAverage := float64(5+10+15+20+25+30) / 6.0 if stats.AveragePlayers != expectedAverage { t.Errorf("Expected average players %.1f, got %.1f", expectedAverage, stats.AveragePlayers) @@ -250,12 +213,10 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) { } func TestStateHistoryRepository_GetSummaryStats_NoData(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -266,25 +227,21 @@ func TestStateHistoryRepository_GetSummaryStats_NoData(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Test GetSummaryStats with no data filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String()) stats, err := repo.GetSummaryStats(ctx, filter) tests.AssertNoError(t, err) - // Verify stats are zero for empty dataset tests.AssertEqual(t, 0, stats.PeakPlayers) tests.AssertEqual(t, 0.0, stats.AveragePlayers) tests.AssertEqual(t, 0, stats.TotalSessions) } func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) { - // Setup environment and test helper tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -294,13 +251,10 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - - // Create test data spanning a time range sessionID := uuid.New() baseTime := time.Now().UTC() - // Create entries spanning 1 hour with players > 0 entries := []model.StateHistory{ { ID: uuid.New(), @@ -342,7 +296,6 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) { tests.AssertNoError(t, err) } - // Test GetTotalPlaytime filter := &model.StateHistoryFilter{ ServerBasedFilter: model.ServerBasedFilter{ ServerID: helper.TestData.ServerID.String(), @@ -355,20 +308,16 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) { playtime, err := repo.GetTotalPlaytime(ctx, filter) tests.AssertNoError(t, err) - - // Should calculate playtime based on session duration if playtime <= 0 { t.Error("Expected positive playtime for session with multiple entries") } } func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) { - // Test concurrent database operations tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -379,10 +328,8 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Create test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -390,7 +337,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) { } } - // Create and insert initial entry to ensure table exists and is properly set up initialHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) err := repo.Insert(ctx, &initialHistory) if err != nil { @@ -399,7 +345,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) { done := make(chan bool, 3) - // Concurrent inserts go func() { defer func() { done <- true @@ -412,7 +357,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) { } }() - // Concurrent reads go func() { defer func() { done <- true @@ -425,7 +369,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) { } }() - // Concurrent GetLastSessionID go func() { defer func() { done <- true @@ -437,19 +380,16 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) { } }() - // Wait for all operations to complete for i := 0; i < 3; i++ { <-done } } func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) { - // Test edge cases with filters tests.SetTestEnv() helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -460,15 +400,11 @@ func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) ctx := helper.CreateContext() - // Insert a test record to ensure the table is properly set up testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) err := repo.Insert(ctx, &history) tests.AssertNoError(t, err) - // Skip nil filter test as it might not be supported by the repository implementation - - // Test with server ID filter - this should work serverFilter := &model.StateHistoryFilter{ ServerBasedFilter: model.ServerBasedFilter{ ServerID: helper.TestData.ServerID.String(), @@ -478,7 +414,6 @@ func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) { tests.AssertNoError(t, err) tests.AssertNotNil(t, result) - // Test with invalid server ID in summary stats invalidFilter := &model.StateHistoryFilter{ ServerBasedFilter: model.ServerBasedFilter{ ServerID: "invalid-uuid", diff --git a/tests/unit/service/auth_simple_test.go b/tests/unit/service/auth_simple_test.go index 361542c..786a8be 100644 --- a/tests/unit/service/auth_simple_test.go +++ b/tests/unit/service/auth_simple_test.go @@ -12,30 +12,24 @@ import ( ) func TestJWT_GenerateAndValidateToken(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET")) - - // Create test user user := &model.User{ ID: uuid.New(), Username: "testuser", RoleID: uuid.New(), } - // Test JWT generation token, err := jwtHandler.GenerateToken(user.ID.String()) tests.AssertNoError(t, err) tests.AssertNotNil(t, token) - // Verify token is not empty if token == "" { t.Fatal("Expected non-empty token, got empty string") } - // Test JWT validation claims, err := jwtHandler.ValidateToken(token) tests.AssertNoError(t, err) tests.AssertNotNil(t, claims) @@ -43,80 +37,66 @@ func TestJWT_GenerateAndValidateToken(t *testing.T) { } func TestJWT_ValidateToken_InvalidToken(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET")) - // Test with invalid token claims, err := jwtHandler.ValidateToken("invalid-token") if err == nil { t.Fatal("Expected error for invalid token, got nil") } - // Direct nil check to avoid the interface wrapping issue if claims != nil { t.Fatalf("Expected nil claims, got %v", claims) } } func TestJWT_ValidateToken_EmptyToken(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET")) - // Test with empty token claims, err := jwtHandler.ValidateToken("") if err == nil { t.Fatal("Expected error for empty token, got nil") } - // Direct nil check to avoid the interface wrapping issue if claims != nil { t.Fatalf("Expected nil claims, got %v", claims) } } func TestUser_VerifyPassword_Success(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test user user := &model.User{ ID: uuid.New(), Username: "testuser", RoleID: uuid.New(), } - // Hash password manually (simulating what BeforeCreate would do) plainPassword := "password123" hashedPassword, err := password.HashPassword(plainPassword) tests.AssertNoError(t, err) user.Password = hashedPassword - // Test password verification - should succeed err = user.VerifyPassword(plainPassword) tests.AssertNoError(t, err) } func TestUser_VerifyPassword_Failure(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test user user := &model.User{ ID: uuid.New(), Username: "testuser", RoleID: uuid.New(), } - // Hash password manually hashedPassword, err := password.HashPassword("correct_password") tests.AssertNoError(t, err) user.Password = hashedPassword - // Test password verification with wrong password - should fail err = user.VerifyPassword("wrong_password") if err == nil { t.Fatal("Expected error for wrong password, got nil") @@ -124,11 +104,9 @@ func TestUser_VerifyPassword_Failure(t *testing.T) { } func TestUser_Validate_Success(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create valid user user := &model.User{ ID: uuid.New(), Username: "testuser", @@ -136,25 +114,21 @@ func TestUser_Validate_Success(t *testing.T) { RoleID: uuid.New(), } - // Test validation - should succeed err := user.Validate() tests.AssertNoError(t, err) } func TestUser_Validate_MissingUsername(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create user without username user := &model.User{ ID: uuid.New(), - Username: "", // Missing username + Username: "", Password: "password123", RoleID: uuid.New(), } - // Test validation - should fail err := user.Validate() if err == nil { t.Fatal("Expected error for missing username, got nil") @@ -162,19 +136,16 @@ func TestUser_Validate_MissingUsername(t *testing.T) { } func TestUser_Validate_MissingPassword(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create user without password user := &model.User{ ID: uuid.New(), Username: "testuser", - Password: "", // Missing password + Password: "", RoleID: uuid.New(), } - // Test validation - should fail err := user.Validate() if err == nil { t.Fatal("Expected error for missing password, got nil") @@ -182,24 +153,19 @@ func TestUser_Validate_MissingPassword(t *testing.T) { } func TestPassword_HashAndVerify(t *testing.T) { - // Test password hashing and verification directly plainPassword := "test_password_123" - // Hash password hashedPassword, err := password.HashPassword(plainPassword) tests.AssertNoError(t, err) tests.AssertNotNil(t, hashedPassword) - // Verify hashed password is not the same as plain password if hashedPassword == plainPassword { t.Fatal("Hashed password should not equal plain password") } - // Verify correct password err = password.VerifyPassword(hashedPassword, plainPassword) tests.AssertNoError(t, err) - // Verify wrong password fails err = password.VerifyPassword(hashedPassword, "wrong_password") if err == nil { t.Fatal("Expected error for wrong password, got nil") @@ -215,7 +181,7 @@ func TestPassword_ValidatePasswordStrength(t *testing.T) { {"Valid password", "StrongPassword123!", false}, {"Too short", "123", true}, {"Empty password", "", true}, - {"Medium password", "password123", false}, // Depends on validation rules + {"Medium password", "password123", false}, } for _, tc := range testCases { @@ -235,7 +201,6 @@ func TestPassword_ValidatePasswordStrength(t *testing.T) { } func TestRole_Model(t *testing.T) { - // Test Role model structure permissions := []model.Permission{ {ID: uuid.New(), Name: "read"}, {ID: uuid.New(), Name: "write"}, @@ -248,7 +213,6 @@ func TestRole_Model(t *testing.T) { Permissions: permissions, } - // Verify role structure tests.AssertEqual(t, "Test Role", role.Name) tests.AssertEqual(t, 3, len(role.Permissions)) tests.AssertEqual(t, "read", role.Permissions[0].Name) @@ -257,19 +221,16 @@ func TestRole_Model(t *testing.T) { } func TestPermission_Model(t *testing.T) { - // Test Permission model structure permission := &model.Permission{ ID: uuid.New(), Name: "test_permission", } - // Verify permission structure tests.AssertEqual(t, "test_permission", permission.Name) tests.AssertNotNil(t, permission.ID) } func TestUser_WithRole_Model(t *testing.T) { - // Test User model with Role relationship permissions := []model.Permission{ {ID: uuid.New(), Name: "read"}, {ID: uuid.New(), Name: "write"}, @@ -289,7 +250,6 @@ func TestUser_WithRole_Model(t *testing.T) { Role: role, } - // Verify user-role relationship tests.AssertEqual(t, "testuser", user.Username) tests.AssertEqual(t, role.ID, user.RoleID) tests.AssertEqual(t, "User", user.Role.Name) diff --git a/tests/unit/service/cache_service_test.go b/tests/unit/service/cache_service_test.go index 4cadf1d..95862af 100644 --- a/tests/unit/service/cache_service_test.go +++ b/tests/unit/service/cache_service_test.go @@ -11,28 +11,22 @@ import ( ) func TestInMemoryCache_Set_Get_Success(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test data key := "test-key" value := "test-value" duration := 5 * time.Minute - // Set value in cache c.Set(key, value, duration) - // Get value from cache result, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, value, result) } func TestInMemoryCache_Get_NotFound(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Try to get non-existent key result, found := c.Get("non-existent-key") tests.AssertEqual(t, false, found) if result != nil { @@ -41,43 +35,33 @@ func TestInMemoryCache_Get_NotFound(t *testing.T) { } func TestInMemoryCache_Set_Get_NoExpiration(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test data key := "test-key" value := "test-value" - // Set value without expiration (duration = 0) c.Set(key, value, 0) - // Get value from cache result, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, value, result) } func TestInMemoryCache_Expiration(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test data key := "test-key" value := "test-value" - duration := 1 * time.Millisecond // Very short duration + duration := 1 * time.Millisecond - // Set value in cache c.Set(key, value, duration) - // Verify it's initially there result, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, value, result) - // Wait for expiration time.Sleep(2 * time.Millisecond) - // Try to get expired value result, found = c.Get(key) tests.AssertEqual(t, false, found) if result != nil { @@ -86,26 +70,20 @@ func TestInMemoryCache_Expiration(t *testing.T) { } func TestInMemoryCache_Delete(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test data key := "test-key" value := "test-value" duration := 5 * time.Minute - // Set value in cache c.Set(key, value, duration) - // Verify it's there result, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, value, result) - // Delete the key c.Delete(key) - // Verify it's gone result, found = c.Get(key) tests.AssertEqual(t, false, found) if result != nil { @@ -114,37 +92,29 @@ func TestInMemoryCache_Delete(t *testing.T) { } func TestInMemoryCache_Overwrite(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test data key := "test-key" value1 := "test-value-1" value2 := "test-value-2" duration := 5 * time.Minute - // Set first value c.Set(key, value1, duration) - // Verify first value result, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, value1, result) - // Overwrite with second value c.Set(key, value2, duration) - // Verify second value result, found = c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, value2, result) } func TestInMemoryCache_Multiple_Keys(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test data testData := map[string]string{ "key1": "value1", "key2": "value2", @@ -152,22 +122,18 @@ func TestInMemoryCache_Multiple_Keys(t *testing.T) { } duration := 5 * time.Minute - // Set multiple values for key, value := range testData { c.Set(key, value, duration) } - // Verify all values for key, expectedValue := range testData { result, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, expectedValue, result) } - // Delete one key c.Delete("key2") - // Verify key2 is gone but others remain result, found := c.Get("key2") tests.AssertEqual(t, false, found) if result != nil { @@ -184,10 +150,8 @@ func TestInMemoryCache_Multiple_Keys(t *testing.T) { } func TestInMemoryCache_Complex_Objects(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test with complex object (User struct) user := &model.User{ ID: uuid.New(), Username: "testuser", @@ -196,15 +160,12 @@ func TestInMemoryCache_Complex_Objects(t *testing.T) { key := "user:" + user.ID.String() duration := 5 * time.Minute - // Set user in cache c.Set(key, user, duration) - // Get user from cache result, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertNotNil(t, result) - // Verify it's the same user cachedUser, ok := result.(*model.User) tests.AssertEqual(t, true, ok) tests.AssertEqual(t, user.ID, cachedUser.ID) @@ -212,33 +173,27 @@ func TestInMemoryCache_Complex_Objects(t *testing.T) { } func TestInMemoryCache_GetOrSet_CacheHit(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Pre-populate cache key := "test-key" expectedValue := "cached-value" c.Set(key, expectedValue, 5*time.Minute) - // Track if fetcher is called fetcherCalled := false fetcher := func() (string, error) { fetcherCalled = true return "fetcher-value", nil } - // Use GetOrSet - should return cached value result, err := cache.GetOrSet(c, key, 5*time.Minute, fetcher) tests.AssertNoError(t, err) tests.AssertEqual(t, expectedValue, result) - tests.AssertEqual(t, false, fetcherCalled) // Fetcher should not be called + tests.AssertEqual(t, false, fetcherCalled) } func TestInMemoryCache_GetOrSet_CacheMiss(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Track if fetcher is called fetcherCalled := false expectedValue := "fetcher-value" fetcher := func() (string, error) { @@ -248,35 +203,29 @@ func TestInMemoryCache_GetOrSet_CacheMiss(t *testing.T) { key := "test-key" - // Use GetOrSet - should call fetcher and cache result result, err := cache.GetOrSet(c, key, 5*time.Minute, fetcher) tests.AssertNoError(t, err) tests.AssertEqual(t, expectedValue, result) - tests.AssertEqual(t, true, fetcherCalled) // Fetcher should be called + tests.AssertEqual(t, true, fetcherCalled) - // Verify value is now cached cachedResult, found := c.Get(key) tests.AssertEqual(t, true, found) tests.AssertEqual(t, expectedValue, cachedResult) } func TestInMemoryCache_GetOrSet_FetcherError(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Fetcher that returns error fetcher := func() (string, error) { return "", tests.ErrorForTesting("fetcher error") } key := "test-key" - // Use GetOrSet - should return error result, err := cache.GetOrSet(c, key, 5*time.Minute, fetcher) tests.AssertError(t, err, "") tests.AssertEqual(t, "", result) - // Verify nothing is cached cachedResult, found := c.Get(key) tests.AssertEqual(t, false, found) if cachedResult != nil { @@ -285,10 +234,8 @@ func TestInMemoryCache_GetOrSet_FetcherError(t *testing.T) { } func TestInMemoryCache_TypeSafety(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test type safety with GetOrSet userFetcher := func() (*model.User, error) { return &model.User{ ID: uuid.New(), @@ -298,13 +245,11 @@ func TestInMemoryCache_TypeSafety(t *testing.T) { key := "user-key" - // Use GetOrSet with User type user, err := cache.GetOrSet(c, key, 5*time.Minute, userFetcher) tests.AssertNoError(t, err) tests.AssertNotNil(t, user) tests.AssertEqual(t, "testuser", user.Username) - // Verify correct type is cached cachedResult, found := c.Get(key) tests.AssertEqual(t, true, found) cachedUser, ok := cachedResult.(*model.User) @@ -313,26 +258,21 @@ func TestInMemoryCache_TypeSafety(t *testing.T) { } func TestInMemoryCache_Concurrent_Access(t *testing.T) { - // Setup c := cache.NewInMemoryCache() - // Test concurrent access key := "concurrent-key" value := "concurrent-value" duration := 5 * time.Minute - // Run concurrent operations done := make(chan bool, 3) - // Goroutine 1: Set value go func() { c.Set(key, value, duration) done <- true }() - // Goroutine 2: Get value go func() { - time.Sleep(1 * time.Millisecond) // Small delay to ensure Set happens first + time.Sleep(1 * time.Millisecond) result, found := c.Get(key) if found { tests.AssertEqual(t, value, result) @@ -340,19 +280,16 @@ func TestInMemoryCache_Concurrent_Access(t *testing.T) { done <- true }() - // Goroutine 3: Delete value go func() { - time.Sleep(2 * time.Millisecond) // Delay to ensure Set and Get happen first + time.Sleep(2 * time.Millisecond) c.Delete(key) done <- true }() - // Wait for all goroutines to complete for i := 0; i < 3; i++ { <-done } - // Verify value is deleted result, found := c.Get(key) tests.AssertEqual(t, false, found) if result != nil { @@ -361,7 +298,6 @@ func TestInMemoryCache_Concurrent_Access(t *testing.T) { } func TestServerStatusCache_GetStatus_NeedsRefresh(t *testing.T) { - // Setup config := model.CacheConfig{ ExpirationTime: 5 * time.Minute, ThrottleTime: 1 * time.Second, @@ -371,14 +307,12 @@ func TestServerStatusCache_GetStatus_NeedsRefresh(t *testing.T) { serviceName := "test-service" - // Initial call - should need refresh status, needsRefresh := cache.GetStatus(serviceName) tests.AssertEqual(t, model.StatusUnknown, status) tests.AssertEqual(t, true, needsRefresh) } func TestServerStatusCache_UpdateStatus_GetStatus(t *testing.T) { - // Setup config := model.CacheConfig{ ExpirationTime: 5 * time.Minute, ThrottleTime: 1 * time.Second, @@ -389,17 +323,14 @@ func TestServerStatusCache_UpdateStatus_GetStatus(t *testing.T) { serviceName := "test-service" expectedStatus := model.StatusRunning - // Update status cache.UpdateStatus(serviceName, expectedStatus) - // Get status - should return cached value status, needsRefresh := cache.GetStatus(serviceName) tests.AssertEqual(t, expectedStatus, status) tests.AssertEqual(t, false, needsRefresh) } func TestServerStatusCache_Throttling(t *testing.T) { - // Setup config := model.CacheConfig{ ExpirationTime: 5 * time.Minute, ThrottleTime: 100 * time.Millisecond, @@ -409,58 +340,44 @@ func TestServerStatusCache_Throttling(t *testing.T) { serviceName := "test-service" - // Update status cache.UpdateStatus(serviceName, model.StatusRunning) - // Immediate call - should return cached value status, needsRefresh := cache.GetStatus(serviceName) tests.AssertEqual(t, model.StatusRunning, status) tests.AssertEqual(t, false, needsRefresh) - // Call within throttle time - should return cached/default status status, needsRefresh = cache.GetStatus(serviceName) tests.AssertEqual(t, model.StatusRunning, status) tests.AssertEqual(t, false, needsRefresh) - // Wait for throttle time to pass time.Sleep(150 * time.Millisecond) - // Call after throttle time - don't check the specific value of needsRefresh - // as it may vary depending on the implementation _, _ = cache.GetStatus(serviceName) - // Test passes if we reach this point without errors } func TestServerStatusCache_Expiration(t *testing.T) { - // Setup config := model.CacheConfig{ - ExpirationTime: 50 * time.Millisecond, // Very short expiration + ExpirationTime: 50 * time.Millisecond, ThrottleTime: 10 * time.Millisecond, DefaultStatus: model.StatusUnknown, } cache := model.NewServerStatusCache(config) serviceName := "test-service" - - // Update status cache.UpdateStatus(serviceName, model.StatusRunning) - // Immediate call - should return cached value status, needsRefresh := cache.GetStatus(serviceName) tests.AssertEqual(t, model.StatusRunning, status) tests.AssertEqual(t, false, needsRefresh) - // Wait for expiration time.Sleep(60 * time.Millisecond) - // Call after expiration - should need refresh status, needsRefresh = cache.GetStatus(serviceName) tests.AssertEqual(t, true, needsRefresh) } func TestServerStatusCache_InvalidateStatus(t *testing.T) { - // Setup config := model.CacheConfig{ ExpirationTime: 5 * time.Minute, ThrottleTime: 1 * time.Second, @@ -470,25 +387,20 @@ func TestServerStatusCache_InvalidateStatus(t *testing.T) { serviceName := "test-service" - // Update status cache.UpdateStatus(serviceName, model.StatusRunning) - // Verify it's cached status, needsRefresh := cache.GetStatus(serviceName) tests.AssertEqual(t, model.StatusRunning, status) tests.AssertEqual(t, false, needsRefresh) - // Invalidate status cache.InvalidateStatus(serviceName) - // Should need refresh now status, needsRefresh = cache.GetStatus(serviceName) tests.AssertEqual(t, model.StatusUnknown, status) tests.AssertEqual(t, true, needsRefresh) } func TestServerStatusCache_Clear(t *testing.T) { - // Setup config := model.CacheConfig{ ExpirationTime: 5 * time.Minute, ThrottleTime: 1 * time.Second, @@ -496,23 +408,19 @@ func TestServerStatusCache_Clear(t *testing.T) { } cache := model.NewServerStatusCache(config) - // Update multiple services services := []string{"service1", "service2", "service3"} for _, service := range services { cache.UpdateStatus(service, model.StatusRunning) } - // Verify all are cached for _, service := range services { status, needsRefresh := cache.GetStatus(service) tests.AssertEqual(t, model.StatusRunning, status) tests.AssertEqual(t, false, needsRefresh) } - // Clear cache cache.Clear() - // All should need refresh now for _, service := range services { status, needsRefresh := cache.GetStatus(service) tests.AssertEqual(t, model.StatusUnknown, status) @@ -521,30 +429,23 @@ func TestServerStatusCache_Clear(t *testing.T) { } func TestLookupCache_SetGetClear(t *testing.T) { - // Setup cache := model.NewLookupCache() - // Test data key := "lookup-key" value := map[string]string{"test": "data"} - // Set value cache.Set(key, value) - // Get value result, found := cache.Get(key) tests.AssertEqual(t, true, found) tests.AssertNotNil(t, result) - // Verify it's the same data resultMap, ok := result.(map[string]string) tests.AssertEqual(t, true, ok) tests.AssertEqual(t, "data", resultMap["test"]) - // Clear cache cache.Clear() - // Should be gone now result, found = cache.Get(key) tests.AssertEqual(t, false, found) if result != nil { @@ -553,7 +454,6 @@ func TestLookupCache_SetGetClear(t *testing.T) { } func TestServerConfigCache_Configuration(t *testing.T) { - // Setup config := model.CacheConfig{ ExpirationTime: 5 * time.Minute, ThrottleTime: 1 * time.Second, @@ -571,17 +471,14 @@ func TestServerConfigCache_Configuration(t *testing.T) { ConfigVersion: model.IntString(1), } - // Initial get - should miss result, found := cache.GetConfiguration(serverID) tests.AssertEqual(t, false, found) if result != nil { t.Fatal("Expected nil result, got non-nil") } - // Update cache cache.UpdateConfiguration(serverID, configuration) - // Get from cache - should hit result, found = cache.GetConfiguration(serverID) tests.AssertEqual(t, true, found) tests.AssertNotNil(t, result) @@ -590,7 +487,6 @@ func TestServerConfigCache_Configuration(t *testing.T) { } func TestServerConfigCache_InvalidateServerCache(t *testing.T) { - // Setup config := model.CacheConfig{ ExpirationTime: 5 * time.Minute, ThrottleTime: 1 * time.Second, @@ -602,11 +498,9 @@ func TestServerConfigCache_InvalidateServerCache(t *testing.T) { configuration := model.Configuration{UdpPort: model.IntString(9231)} assistRules := model.AssistRules{StabilityControlLevelMax: model.IntString(0)} - // Update multiple configs for server cache.UpdateConfiguration(serverID, configuration) cache.UpdateAssistRules(serverID, assistRules) - // Verify both are cached configResult, found := cache.GetConfiguration(serverID) tests.AssertEqual(t, true, found) tests.AssertNotNil(t, configResult) @@ -615,10 +509,8 @@ func TestServerConfigCache_InvalidateServerCache(t *testing.T) { tests.AssertEqual(t, true, found) tests.AssertNotNil(t, assistResult) - // Invalidate server cache cache.InvalidateServerCache(serverID) - // Both should be gone configResult, found = cache.GetConfiguration(serverID) tests.AssertEqual(t, false, found) if configResult != nil { diff --git a/tests/unit/service/config_service_test.go b/tests/unit/service/config_service_test.go index af378d0..c75fb50 100644 --- a/tests/unit/service/config_service_test.go +++ b/tests/unit/service/config_service_test.go @@ -12,25 +12,20 @@ import ( ) func TestConfigService_GetConfiguration_ValidFile(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test config files err := helper.CreateTestConfigFiles() tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // Test GetConfiguration config, err := configService.GetConfiguration(helper.TestData.Server) tests.AssertNoError(t, err) tests.AssertNotNil(t, config) - // Verify the result is the expected configuration tests.AssertEqual(t, model.IntString(9231), config.UdpPort) tests.AssertEqual(t, model.IntString(9232), config.TcpPort) tests.AssertEqual(t, model.IntString(30), config.MaxConnections) @@ -40,21 +35,17 @@ func TestConfigService_GetConfiguration_ValidFile(t *testing.T) { } func TestConfigService_GetConfiguration_MissingFile(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create server directory but no config files serverConfigDir := filepath.Join(helper.TestData.Server.Path, "cfg") err := os.MkdirAll(serverConfigDir, 0755) tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // Test GetConfiguration for missing file config, err := configService.GetConfiguration(helper.TestData.Server) if err == nil { t.Fatal("Expected error for missing file, got nil") @@ -65,25 +56,20 @@ func TestConfigService_GetConfiguration_MissingFile(t *testing.T) { } func TestConfigService_GetEventConfig_ValidFile(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test config files err := helper.CreateTestConfigFiles() tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // Test GetEventConfig eventConfig, err := configService.GetEventConfig(helper.TestData.Server) tests.AssertNoError(t, err) tests.AssertNotNil(t, eventConfig) - // Verify the result is the expected event configuration tests.AssertEqual(t, "spa", eventConfig.Track) tests.AssertEqual(t, model.IntString(80), eventConfig.PreRaceWaitingTimeSeconds) tests.AssertEqual(t, model.IntString(120), eventConfig.SessionOverTimeSeconds) @@ -91,7 +77,6 @@ func TestConfigService_GetEventConfig_ValidFile(t *testing.T) { tests.AssertEqual(t, float64(0.3), eventConfig.CloudLevel) tests.AssertEqual(t, float64(0.0), eventConfig.Rain) - // Verify sessions tests.AssertEqual(t, 3, len(eventConfig.Sessions)) if len(eventConfig.Sessions) > 0 { tests.AssertEqual(t, model.SessionPractice, eventConfig.Sessions[0].SessionType) @@ -101,20 +86,16 @@ func TestConfigService_GetEventConfig_ValidFile(t *testing.T) { func TestConfigService_SaveConfiguration_Success(t *testing.T) { t.Skip("Temporarily disabled due to path issues") - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test config files err := helper.CreateTestConfigFiles() tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // Prepare new configuration newConfig := &model.Configuration{ UdpPort: model.IntString(9999), TcpPort: model.IntString(10000), @@ -124,16 +105,13 @@ func TestConfigService_SaveConfiguration_Success(t *testing.T) { ConfigVersion: model.IntString(2), } - // Test SaveConfiguration err = configService.SaveConfiguration(helper.TestData.Server, newConfig) tests.AssertNoError(t, err) - // Verify the configuration was saved configPath := filepath.Join(helper.TestData.Server.Path, "cfg", "configuration.json") fileContent, err := os.ReadFile(configPath) tests.AssertNoError(t, err) - // Convert from UTF-16 to UTF-8 for verification utf8Content, err := service.DecodeUTF16LEBOM(fileContent) tests.AssertNoError(t, err) @@ -141,7 +119,6 @@ func TestConfigService_SaveConfiguration_Success(t *testing.T) { err = json.Unmarshal(utf8Content, &savedConfig) tests.AssertNoError(t, err) - // Verify the saved values tests.AssertEqual(t, "9999", savedConfig["udpPort"]) tests.AssertEqual(t, "10000", savedConfig["tcpPort"]) tests.AssertEqual(t, "40", savedConfig["maxConnections"]) @@ -151,25 +128,20 @@ func TestConfigService_SaveConfiguration_Success(t *testing.T) { } func TestConfigService_LoadConfigs_Success(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test config files err := helper.CreateTestConfigFiles() tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // Test LoadConfigs configs, err := configService.LoadConfigs(helper.TestData.Server) tests.AssertNoError(t, err) tests.AssertNotNil(t, configs) - // Verify all configurations are loaded tests.AssertEqual(t, model.IntString(9231), configs.Configuration.UdpPort) tests.AssertEqual(t, model.IntString(9232), configs.Configuration.TcpPort) tests.AssertEqual(t, "Test ACC Server", configs.Settings.ServerName) @@ -183,21 +155,17 @@ func TestConfigService_LoadConfigs_Success(t *testing.T) { } func TestConfigService_LoadConfigs_MissingFiles(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create server directory but no config files serverConfigDir := filepath.Join(helper.TestData.Server.Path, "cfg") err := os.MkdirAll(serverConfigDir, 0755) tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // Test LoadConfigs with missing files configs, err := configService.LoadConfigs(helper.TestData.Server) if err == nil { t.Fatal("Expected error for missing files, got nil") @@ -208,20 +176,16 @@ func TestConfigService_LoadConfigs_MissingFiles(t *testing.T) { } func TestConfigService_MalformedJSON(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create malformed config file err := helper.CreateMalformedConfigFile("configuration.json") tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // Test GetConfiguration with malformed JSON config, err := configService.GetConfiguration(helper.TestData.Server) if err == nil { t.Fatal("Expected error for malformed JSON, got nil") @@ -232,31 +196,24 @@ func TestConfigService_MalformedJSON(t *testing.T) { } func TestConfigService_UTF16_Encoding(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test config files err := helper.CreateTestConfigFiles() tests.AssertNoError(t, err) - // Test UTF-16 encoding and decoding originalData := `{"udpPort": "9231", "tcpPort": "9232"}` - // Encode to UTF-16 LE BOM encoded, err := service.EncodeUTF16LEBOM([]byte(originalData)) tests.AssertNoError(t, err) - // Decode back to UTF-8 decoded, err := service.DecodeUTF16LEBOM(encoded) tests.AssertNoError(t, err) - // Verify it matches original tests.AssertEqual(t, originalData, string(decoded)) } func TestConfigService_DecodeFileName(t *testing.T) { - // Test that all supported file names have decoders testCases := []string{ "configuration.json", "assistRules.json", @@ -272,7 +229,6 @@ func TestConfigService_DecodeFileName(t *testing.T) { }) } - // Test invalid filename decoder := service.DecodeFileName("invalid.json") if decoder != nil { t.Fatal("Expected nil decoder for invalid filename, got non-nil") @@ -280,22 +236,18 @@ func TestConfigService_DecodeFileName(t *testing.T) { } func TestConfigService_IntString_Conversion(t *testing.T) { - // Test IntString unmarshaling from string var intStr model.IntString - // Test string input err := json.Unmarshal([]byte(`"123"`), &intStr) tests.AssertNoError(t, err) tests.AssertEqual(t, 123, intStr.ToInt()) tests.AssertEqual(t, "123", intStr.ToString()) - // Test int input err = json.Unmarshal([]byte(`456`), &intStr) tests.AssertNoError(t, err) tests.AssertEqual(t, 456, intStr.ToInt()) tests.AssertEqual(t, "456", intStr.ToString()) - // Test empty string err = json.Unmarshal([]byte(`""`), &intStr) tests.AssertNoError(t, err) tests.AssertEqual(t, 0, intStr.ToInt()) @@ -303,28 +255,23 @@ func TestConfigService_IntString_Conversion(t *testing.T) { } func TestConfigService_IntBool_Conversion(t *testing.T) { - // Test IntBool unmarshaling from int var intBool model.IntBool - // Test int input (1 = true) err := json.Unmarshal([]byte(`1`), &intBool) tests.AssertNoError(t, err) tests.AssertEqual(t, 1, intBool.ToInt()) tests.AssertEqual(t, true, intBool.ToBool()) - // Test int input (0 = false) err = json.Unmarshal([]byte(`0`), &intBool) tests.AssertNoError(t, err) tests.AssertEqual(t, 0, intBool.ToInt()) tests.AssertEqual(t, false, intBool.ToBool()) - // Test bool input (true) err = json.Unmarshal([]byte(`true`), &intBool) tests.AssertNoError(t, err) tests.AssertEqual(t, 1, intBool.ToInt()) tests.AssertEqual(t, true, intBool.ToBool()) - // Test bool input (false) err = json.Unmarshal([]byte(`false`), &intBool) tests.AssertNoError(t, err) tests.AssertEqual(t, 0, intBool.ToInt()) @@ -332,25 +279,20 @@ func TestConfigService_IntBool_Conversion(t *testing.T) { } func TestConfigService_Caching_Configuration(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test config files (already UTF-16 encoded) err := helper.CreateTestConfigFiles() tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // First call - should load from disk config1, err := configService.GetConfiguration(helper.TestData.Server) tests.AssertNoError(t, err) tests.AssertNotNil(t, config1) - // Modify the file on disk with UTF-16 encoding configPath := filepath.Join(helper.TestData.Server.Path, "cfg", "configuration.json") modifiedContent := `{"udpPort": "5555", "tcpPort": "5556"}` utf16Modified, err := service.EncodeUTF16LEBOM([]byte(modifiedContent)) @@ -359,36 +301,29 @@ func TestConfigService_Caching_Configuration(t *testing.T) { err = os.WriteFile(configPath, utf16Modified, 0644) tests.AssertNoError(t, err) - // Second call - should return cached result (not the modified file) config2, err := configService.GetConfiguration(helper.TestData.Server) tests.AssertNoError(t, err) tests.AssertNotNil(t, config2) - // Should still have the original cached values tests.AssertEqual(t, model.IntString(9231), config2.UdpPort) tests.AssertEqual(t, model.IntString(9232), config2.TcpPort) } func TestConfigService_Caching_EventConfig(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Create test config files (already UTF-16 encoded) err := helper.CreateTestConfigFiles() tests.AssertNoError(t, err) - // Create repositories and service configRepo := repository.NewConfigRepository(helper.DB) serverRepo := repository.NewServerRepository(helper.DB) configService := service.NewConfigService(configRepo, serverRepo) - // First call - should load from disk event1, err := configService.GetEventConfig(helper.TestData.Server) tests.AssertNoError(t, err) tests.AssertNotNil(t, event1) - // Modify the file on disk with UTF-16 encoding configPath := filepath.Join(helper.TestData.Server.Path, "cfg", "event.json") modifiedContent := `{"track": "monza", "preRaceWaitingTimeSeconds": "60"}` utf16Modified, err := service.EncodeUTF16LEBOM([]byte(modifiedContent)) @@ -397,12 +332,10 @@ func TestConfigService_Caching_EventConfig(t *testing.T) { err = os.WriteFile(configPath, utf16Modified, 0644) tests.AssertNoError(t, err) - // Second call - should return cached result (not the modified file) event2, err := configService.GetEventConfig(helper.TestData.Server) tests.AssertNoError(t, err) tests.AssertNotNil(t, event2) - // Should still have the original cached values tests.AssertEqual(t, "spa", event2.Track) tests.AssertEqual(t, model.IntString(80), event2.PreRaceWaitingTimeSeconds) } diff --git a/tests/unit/service/state_history_service_test.go b/tests/unit/service/state_history_service_test.go index 9a9fd7b..9cbfcfa 100644 --- a/tests/unit/service/state_history_service_test.go +++ b/tests/unit/service/state_history_service_test.go @@ -15,11 +15,9 @@ import ( ) func TestStateHistoryService_GetAll_Success(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -27,22 +25,17 @@ func TestStateHistoryService_GetAll_Success(t *testing.T) { } } - // Use real repository like other service tests repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Insert test data directly into DB testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) err := repo.Insert(helper.CreateContext(), &history) tests.AssertNoError(t, err) - - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test GetAll filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String()) result, err := stateHistoryService.GetAll(ctx, filter) @@ -54,11 +47,9 @@ func TestStateHistoryService_GetAll_Success(t *testing.T) { } func TestStateHistoryService_GetAll_WithFilter(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -69,7 +60,6 @@ func TestStateHistoryService_GetAll_WithFilter(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Insert test data with different sessions testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) raceHistory := testData.CreateStateHistory(model.SessionRace, "spa", 10, uuid.New()) @@ -79,12 +69,10 @@ func TestStateHistoryService_GetAll_WithFilter(t *testing.T) { err = repo.Insert(helper.CreateContext(), &raceHistory) tests.AssertNoError(t, err) - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test GetAll with session filter filter := testdata.CreateFilterWithSession(helper.TestData.ServerID.String(), model.SessionRace) result, err := stateHistoryService.GetAll(ctx, filter) @@ -96,11 +84,9 @@ func TestStateHistoryService_GetAll_WithFilter(t *testing.T) { } func TestStateHistoryService_GetAll_NoData(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -111,12 +97,10 @@ func TestStateHistoryService_GetAll_NoData(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test GetAll with no data filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String()) result, err := stateHistoryService.GetAll(ctx, filter) @@ -126,11 +110,9 @@ func TestStateHistoryService_GetAll_NoData(t *testing.T) { } func TestStateHistoryService_Insert_Success(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -141,20 +123,16 @@ func TestStateHistoryService_Insert_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Create test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New()) - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test Insert err := stateHistoryService.Insert(ctx, &history) tests.AssertNoError(t, err) - // Verify data was inserted filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String()) result, err := stateHistoryService.GetAll(ctx, filter) tests.AssertNoError(t, err) @@ -162,11 +140,9 @@ func TestStateHistoryService_Insert_Success(t *testing.T) { } func TestStateHistoryService_GetLastSessionID_Success(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -177,30 +153,25 @@ func TestStateHistoryService_GetLastSessionID_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Insert test data testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID) sessionID := uuid.New() history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, sessionID) err := repo.Insert(helper.CreateContext(), &history) tests.AssertNoError(t, err) - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test GetLastSessionID lastSessionID, err := stateHistoryService.GetLastSessionID(ctx, helper.TestData.ServerID) tests.AssertNoError(t, err) tests.AssertEqual(t, sessionID, lastSessionID) } func TestStateHistoryService_GetLastSessionID_NoData(t *testing.T) { - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -211,26 +182,21 @@ func TestStateHistoryService_GetLastSessionID_NoData(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test GetLastSessionID with no data lastSessionID, err := stateHistoryService.GetLastSessionID(ctx, helper.TestData.ServerID) tests.AssertNoError(t, err) tests.AssertEqual(t, uuid.Nil, lastSessionID) } func TestStateHistoryService_GetStatistics_Success(t *testing.T) { - // This test might fail due to database setup issues t.Skip("Skipping test as it's dependent on database migration") - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -241,10 +207,8 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Insert test data with varying player counts _ = testdata.NewStateHistoryTestData(helper.TestData.ServerID) - // Create entries with different sessions and player counts sessionID1 := uuid.New() sessionID2 := uuid.New() @@ -291,12 +255,10 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) { tests.AssertNoError(t, err) } - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test GetStatistics filter := &model.StateHistoryFilter{ ServerBasedFilter: model.ServerBasedFilter{ ServerID: helper.TestData.ServerID.String(), @@ -311,17 +273,14 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) { tests.AssertNoError(t, err) tests.AssertNotNil(t, stats) - // Verify statistics - tests.AssertEqual(t, 15, stats.PeakPlayers) // Maximum player count - tests.AssertEqual(t, 2, stats.TotalSessions) // Two unique sessions + tests.AssertEqual(t, 15, stats.PeakPlayers) + tests.AssertEqual(t, 2, stats.TotalSessions) - // Average should be (5+10+15)/3 = 10 expectedAverage := float64(5+10+15) / 3.0 if stats.AveragePlayers != expectedAverage { t.Errorf("Expected average players %.1f, got %.1f", expectedAverage, stats.AveragePlayers) } - // Verify other statistics components exist tests.AssertNotNil(t, stats.PlayerCountOverTime) tests.AssertNotNil(t, stats.SessionTypes) tests.AssertNotNil(t, stats.DailyActivity) @@ -329,14 +288,11 @@ func TestStateHistoryService_GetStatistics_Success(t *testing.T) { } func TestStateHistoryService_GetStatistics_NoData(t *testing.T) { - // This test might fail due to database setup issues t.Skip("Skipping test as it's dependent on database migration") - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Ensure the state_histories table exists if !helper.DB.Migrator().HasTable(&model.StateHistory{}) { err := helper.DB.Migrator().CreateTable(&model.StateHistory{}) if err != nil { @@ -347,19 +303,16 @@ func TestStateHistoryService_GetStatistics_NoData(t *testing.T) { repo := repository.NewStateHistoryRepository(helper.DB) stateHistoryService := service.NewStateHistoryService(repo) - // Create proper Fiber context app := fiber.New() ctx := helper.CreateFiberCtx() defer helper.ReleaseFiberCtx(app, ctx) - // Test GetStatistics with no data filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String()) stats, err := stateHistoryService.GetStatistics(ctx, filter) tests.AssertNoError(t, err) tests.AssertNotNil(t, stats) - // Verify empty statistics tests.AssertEqual(t, 0, stats.PeakPlayers) tests.AssertEqual(t, 0.0, stats.AveragePlayers) tests.AssertEqual(t, 0, stats.TotalSessions) @@ -367,43 +320,32 @@ func TestStateHistoryService_GetStatistics_NoData(t *testing.T) { } func TestStateHistoryService_LogParsingWorkflow(t *testing.T) { - // Skip this test as it's unreliable and not critical t.Skip("Skipping log parsing test as it's not critical to the service functionality") - // This test simulates the actual log parsing workflow - // Setup helper := tests.NewTestHelper(t) defer helper.Cleanup() - // Insert test server err := helper.InsertTestServer() tests.AssertNoError(t, err) server := helper.TestData.Server - // Track state changes var stateChanges []*model.ServerState onStateChange := func(state *model.ServerState, changes ...tracking.StateChange) { - // Use pointer to avoid copying mutex stateChanges = append(stateChanges, state) } - // Create AccServerInstance (this is what the real server service does) instance := tracking.NewAccServerInstance(server, onStateChange) - - // Simulate processing log lines (this tests the actual HandleLogLine functionality) logLines := testdata.SampleLogLines for _, line := range logLines { instance.HandleLogLine(line) } - // Verify state changes were detected if len(stateChanges) == 0 { t.Error("Expected state changes from log parsing, got none") } - // Verify session changes were parsed correctly expectedSessions := []model.TrackSession{model.SessionPractice, model.SessionQualify, model.SessionRace} sessionIndex := 0 @@ -416,18 +358,15 @@ func TestStateHistoryService_LogParsingWorkflow(t *testing.T) { } } - // Verify player count changes were tracked if len(stateChanges) > 0 { finalState := stateChanges[len(stateChanges)-1] - tests.AssertEqual(t, 0, finalState.PlayerCount) // Should end with 0 players + tests.AssertEqual(t, 0, finalState.PlayerCount) } } func TestStateHistoryService_SessionChangeTracking(t *testing.T) { - // Skip this test as it's unreliable t.Skip("Skipping session tracking test as it's unreliable in CI environments") - // Test session change detection server := &model.Server{ ID: uuid.New(), Name: "Test Server", @@ -437,7 +376,6 @@ func TestStateHistoryService_SessionChangeTracking(t *testing.T) { onStateChange := func(state *model.ServerState, changes ...tracking.StateChange) { for _, change := range changes { if change == tracking.Session { - // Create a copy of the session to avoid later mutations sessionCopy := state.Session sessionChanges = append(sessionChanges, sessionCopy) } @@ -446,22 +384,17 @@ func TestStateHistoryService_SessionChangeTracking(t *testing.T) { instance := tracking.NewAccServerInstance(server, onStateChange) - // We'll add one session change at a time and wait briefly to ensure they're processed in order for _, expected := range testdata.ExpectedSessionChanges { line := string("[2024-01-15 14:30:25.123] Session changed: " + expected.From + " -> " + expected.To) instance.HandleLogLine(line) - // Small pause to ensure log processing completes time.Sleep(10 * time.Millisecond) } - // Check if we have any session changes if len(sessionChanges) == 0 { t.Error("No session changes detected") return } - // Just verify the last session change matches what we expect - // This is more reliable than checking the entire sequence lastExpected := testdata.ExpectedSessionChanges[len(testdata.ExpectedSessionChanges)-1].To lastActual := sessionChanges[len(sessionChanges)-1] if lastActual != lastExpected { @@ -470,10 +403,8 @@ func TestStateHistoryService_SessionChangeTracking(t *testing.T) { } func TestStateHistoryService_PlayerCountTracking(t *testing.T) { - // Skip this test as it's unreliable t.Skip("Skipping player count tracking test as it's unreliable in CI environments") - // Test player count change detection server := &model.Server{ ID: uuid.New(), Name: "Test Server", @@ -490,7 +421,6 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) { instance := tracking.NewAccServerInstance(server, onStateChange) - // Test each expected player count change expectedCounts := testdata.ExpectedPlayerCounts logLines := []string{ "[2024-01-15 14:30:30.456] 1 client(s) online", @@ -499,7 +429,7 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) { "[2024-01-15 14:35:05.789] 8 client(s) online", "[2024-01-15 14:40:05.456] 12 client(s) online", "[2024-01-15 14:45:00.789] 15 client(s) online", - "[2024-01-15 14:50:00.789] Removing dead connection", // Should decrease by 1 + "[2024-01-15 14:50:00.789] Removing dead connection", "[2024-01-15 15:00:00.789] 0 client(s) online", } @@ -507,7 +437,6 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) { instance.HandleLogLine(line) } - // Verify all player count changes were detected tests.AssertEqual(t, len(expectedCounts), len(playerCounts)) for i, expected := range expectedCounts { if i < len(playerCounts) { @@ -517,10 +446,8 @@ func TestStateHistoryService_PlayerCountTracking(t *testing.T) { } func TestStateHistoryService_EdgeCases(t *testing.T) { - // Skip this test as it's unreliable t.Skip("Skipping edge cases test as it's unreliable in CI environments") - // Test edge cases in log parsing server := &model.Server{ ID: uuid.New(), Name: "Test Server", @@ -528,34 +455,30 @@ func TestStateHistoryService_EdgeCases(t *testing.T) { var stateChanges []*model.ServerState onStateChange := func(state *model.ServerState, changes ...tracking.StateChange) { - // Create a copy of the state to avoid later mutations affecting our saved state stateCopy := *state stateChanges = append(stateChanges, &stateCopy) } instance := tracking.NewAccServerInstance(server, onStateChange) - // Test edge cases edgeCaseLines := []string{ - "[2024-01-15 14:30:25.123] Some unrelated log line", // Should be ignored - "[2024-01-15 14:30:25.123] Session changed: NONE -> PRACTICE", // Valid session change - "[2024-01-15 14:30:30.456] 0 client(s) online", // Zero players - "[2024-01-15 14:30:35.789] -1 client(s) online", // Invalid negative (should be ignored) - "[2024-01-15 14:30:40.789] 30 client(s) online", // High but valid player count - "[2024-01-15 14:30:45.789] invalid client(s) online", // Invalid format (should be ignored) + "[2024-01-15 14:30:25.123] Some unrelated log line", + "[2024-01-15 14:30:25.123] Session changed: NONE -> PRACTICE", + "[2024-01-15 14:30:30.456] 0 client(s) online", + "[2024-01-15 14:30:35.789] -1 client(s) online", + "[2024-01-15 14:30:40.789] 30 client(s) online", + "[2024-01-15 14:30:45.789] invalid client(s) online", } for _, line := range edgeCaseLines { instance.HandleLogLine(line) } - // Verify we have some state changes if len(stateChanges) == 0 { t.Errorf("Expected state changes, got none") return } - // Look for a state with 30 players - might be in any position due to concurrency found30Players := false for _, state := range stateChanges { if state.PlayerCount == 30 { @@ -564,8 +487,6 @@ func TestStateHistoryService_EdgeCases(t *testing.T) { } } - // Mark the test as passed if we found at least one state with the expected value - // This makes the test more resilient to timing/ordering differences if !found30Players { t.Log("Player counts in recorded states:") for i, state := range stateChanges { @@ -576,10 +497,8 @@ func TestStateHistoryService_EdgeCases(t *testing.T) { } func TestStateHistoryService_SessionStartTracking(t *testing.T) { - // Skip this test as it's unreliable t.Skip("Skipping session start tracking test as it's unreliable in CI environments") - // Test that session start times are tracked correctly server := &model.Server{ ID: uuid.New(), Name: "Test Server", @@ -596,16 +515,13 @@ func TestStateHistoryService_SessionStartTracking(t *testing.T) { instance := tracking.NewAccServerInstance(server, onStateChange) - // Simulate session starting when players join startTime := time.Now() - instance.HandleLogLine("[2024-01-15 14:30:30.456] 1 client(s) online") // First player joins + instance.HandleLogLine("[2024-01-15 14:30:30.456] 1 client(s) online") - // Verify session start was recorded if len(sessionStarts) == 0 { t.Error("Expected session start to be recorded when first player joins") } - // Session start should be close to when we processed the log line if len(sessionStarts) > 0 { timeDiff := sessionStarts[0].Sub(startTime) if timeDiff > time.Second || timeDiff < -time.Second {