19 Commits

Author SHA1 Message Date
Fran Jurmanović
4004d83411 add step list for server creation
All checks were successful
Release and Deploy / build (push) Successful in 9m5s
Release and Deploy / deploy (push) Successful in 26s
2025-09-18 22:24:51 +02:00
Fran Jurmanović
901dbe697e Use sockets for server creation progress 2025-09-18 01:06:58 +02:00
Fran Jurmanović
760412d7db update state history session type
All checks were successful
Release and Deploy / build (push) Successful in 2m25s
Release and Deploy / deploy (push) Successful in 26s
2025-09-15 19:11:25 +02:00
Fran Jurmanović
4ab94de529 resolve test failing
All checks were successful
Release and Deploy / build (push) Successful in 2m12s
Release and Deploy / deploy (push) Successful in 27s
2025-09-14 22:05:25 +02:00
Fran Jurmanović
b3f89593fb replace id type
Some checks failed
Release and Deploy / build (push) Failing after 2m7s
Release and Deploy / deploy (push) Has been skipped
2025-09-14 17:08:50 +02:00
Fran Jurmanović
2a863c51e9 update state history query
All checks were successful
Release and Deploy / build (push) Successful in 3m18s
Release and Deploy / deploy (push) Successful in 27s
2025-09-13 14:41:52 +02:00
Fran Jurmanović
a70d923a6a revert to powershell executable 2025-08-17 16:53:44 +02:00
Fran Jurmanović
f660511b63 change steamCMD executor
All checks were successful
Release and Deploy / build (push) Successful in 2m1s
Release and Deploy / deploy (push) Successful in 23s
2025-08-17 16:42:10 +02:00
Fran Jurmanović
044af60699 steam-crypt app and fix the interactive executor
All checks were successful
Release and Deploy / build (push) Successful in 2m2s
Release and Deploy / deploy (push) Successful in 22s
2025-08-17 16:26:28 +02:00
Fran Jurmanović
384036bcdd remove blocker pattern
All checks were successful
Release and Deploy / build (push) Successful in 2m5s
Release and Deploy / deploy (push) Successful in 26s
2025-08-17 15:53:55 +02:00
Fran Jurmanović
ef300d233b fix wrong userID from context
All checks were successful
Release and Deploy / build (push) Successful in 3m26s
Release and Deploy / deploy (push) Successful in 24s
2025-08-17 13:12:36 +02:00
Fran Jurmanović
edad65d6a9 generate open token using normal token
All checks were successful
Release and Deploy / build (push) Successful in 3m1s
Release and Deploy / deploy (push) Successful in 23s
2025-08-17 12:46:37 +02:00
Fran Jurmanović
486c972bba open token authentication
All checks were successful
Release and Deploy / build (push) Successful in 3m51s
Release and Deploy / deploy (push) Successful in 28s
2025-08-17 12:15:39 +02:00
Fran Jurmanović
aab5d2ad61 steam 2fa for polling and security
All checks were successful
Release and Deploy / build (push) Successful in 6m8s
Release and Deploy / deploy (push) Successful in 27s
2025-08-16 16:43:54 +02:00
Fran Jurmanović
1683d5c2f1 generate swagger docs
All checks were successful
Release and Deploy / build (push) Successful in 1m50s
Release and Deploy / deploy (push) Successful in 21s
2025-08-05 17:09:05 +02:00
Fran Jurmanović
87d4af0bec update host and schemes swagger
All checks were successful
Release and Deploy / build (push) Successful in 1m51s
Release and Deploy / deploy (push) Successful in 21s
2025-08-05 16:51:33 +02:00
Fran Jurmanović
35449a090d update swagger host
All checks were successful
Release and Deploy / build (push) Successful in 1m49s
Release and Deploy / deploy (push) Successful in 21s
2025-08-05 14:52:02 +02:00
Fran Jurmanović
5324a41e05 update swagger base path
All checks were successful
Release and Deploy / build (push) Successful in 1m48s
Release and Deploy / deploy (push) Successful in 22s
2025-08-05 14:39:44 +02:00
Fran Jurmanović
ac61ba5223 update swagger docs
All checks were successful
Release and Deploy / build (push) Successful in 2m16s
Release and Deploy / deploy (push) Successful in 25s
2025-08-05 14:32:37 +02:00
102 changed files with 5896 additions and 2933 deletions

View File

@@ -17,7 +17,6 @@ import (
func main() {
configs.Init()
jwt.Init()
// Initialize new logging system
if err := logging.InitializeLogging(); err != nil {
fmt.Printf("Failed to initialize logging system: %v\n", err)
@@ -37,6 +36,8 @@ func main() {
logging.InfoStartup("APPLICATION", "ACC Server Manager starting up")
di := dig.New()
di.Provide(func() *jwt.JWTHandler { return jwt.NewJWTHandler(os.Getenv("JWT_SECRET")) })
di.Provide(func() *jwt.OpenJWTHandler { return jwt.NewOpenJWTHandler(os.Getenv("JWT_SECRET_OPEN")) })
cache.Start(di)
db.Start(di)
server.Start(di)

View File

@@ -5,14 +5,14 @@
// @description API for managing Assetto Corsa Competizione dedicated servers
//
// @contact.name ACC Server Manager Support
// @contact.url https://github.com/yourusername/acc-server-manager
// @contact.url https://github.com/FJurmanovic/acc-server-manager
//
// @license.name MIT
// @license.url https://opensource.org/licenses/MIT
//
// @host localhost:3000
// @BasePath /api/v1
// @schemes http https
// @host acc-api.jurmanovic.com
// @BasePath /v1
// @schemes https
//
// @securityDefinitions.apikey BearerAuth
// @in header

93
cmd/steam-crypt/main.go Normal file
View File

@@ -0,0 +1,93 @@
package main
import (
"acc-server-manager/local/model"
"acc-server-manager/local/utl/configs"
"flag"
"fmt"
"os"
)
func main() {
var (
encrypt = flag.Bool("encrypt", false, "Encrypt a password")
decrypt = flag.Bool("decrypt", false, "Decrypt a password")
password = flag.String("password", "", "Password to encrypt/decrypt")
help = flag.Bool("help", false, "Show help")
)
flag.Parse()
if *help || (!*encrypt && !*decrypt) {
showHelp()
return
}
if *encrypt && *decrypt {
fmt.Fprintf(os.Stderr, "Error: Cannot specify both -encrypt and -decrypt\n")
os.Exit(1)
}
if *password == "" {
fmt.Fprintf(os.Stderr, "Error: Password is required\n")
showHelp()
os.Exit(1)
}
// Initialize configs to load encryption key
configs.Init()
if *encrypt {
encrypted, err := model.EncryptPassword(*password)
if err != nil {
fmt.Fprintf(os.Stderr, "Error encrypting password: %v\n", err)
os.Exit(1)
}
fmt.Println(encrypted)
}
if *decrypt {
decrypted, err := model.DecryptPassword(*password)
if err != nil {
fmt.Fprintf(os.Stderr, "Error decrypting password: %v\n", err)
os.Exit(1)
}
fmt.Println(decrypted)
}
}
func showHelp() {
fmt.Println("Steam Credentials Encryption/Decryption Utility")
fmt.Println()
fmt.Println("This utility encrypts and decrypts Steam credentials using the same")
fmt.Println("AES-256-GCM encryption used by the ACC Server Manager application.")
fmt.Println()
fmt.Println("Usage:")
fmt.Println(" steam-crypt -encrypt -password \"your_password\"")
fmt.Println(" steam-crypt -decrypt -password \"encrypted_string\"")
fmt.Println()
fmt.Println("Options:")
fmt.Println(" -encrypt Encrypt the provided password")
fmt.Println(" -decrypt Decrypt the provided encrypted string")
fmt.Println(" -password The password to encrypt or encrypted string to decrypt")
fmt.Println(" -help Show this help message")
fmt.Println()
fmt.Println("Environment Variables Required:")
fmt.Println(" ENCRYPTION_KEY - 32-byte encryption key (same as main application)")
fmt.Println(" APP_SECRET - Application secret (required by configs)")
fmt.Println(" APP_SECRET_CODE - Application secret code (required by configs)")
fmt.Println(" ACCESS_KEY - Access key (required by configs)")
fmt.Println()
fmt.Println("Examples:")
fmt.Println(" # Encrypt a password")
fmt.Println(" steam-crypt -encrypt -password \"mysteampassword\"")
fmt.Println()
fmt.Println(" # Decrypt an encrypted password")
fmt.Println(" steam-crypt -decrypt -password \"base64encryptedstring\"")
fmt.Println()
fmt.Println("Security Notes:")
fmt.Println(" - The encryption key must be exactly 32 bytes for AES-256")
fmt.Println(" - Uses AES-256-GCM for authenticated encryption")
fmt.Println(" - Each encryption includes a unique nonce for security")
fmt.Println(" - Passwords are validated for length and basic security")
}

View File

@@ -0,0 +1,243 @@
# Steam 2FA Implementation Documentation
## Overview
This document describes the implementation of Steam Two-Factor Authentication (2FA) support for the ACC Server Manager. When SteamCMD requires 2FA confirmation during server installation or updates, the system now signals the frontend and waits for user confirmation before proceeding.
## Architecture
The 2FA implementation consists of several interconnected components:
### Backend Components
1. **Steam2FAManager** (`local/model/steam_2fa.go`)
- Thread-safe management of 2FA requests
- Request lifecycle tracking (pending → complete/error)
- Channel-based waiting mechanism for synchronization
2. **InteractiveCommandExecutor** (`local/utl/command/interactive_executor.go`)
- Monitors SteamCMD output for 2FA prompts
- Creates 2FA requests when prompts are detected
- Waits for user confirmation before proceeding
3. **Steam2FAController** (`local/controller/steam_2fa.go`)
- REST API endpoints for 2FA management
- Handles frontend requests to complete/cancel 2FA
4. **Updated SteamService** (`local/service/steam_service.go`)
- Uses InteractiveCommandExecutor for SteamCMD operations
- Passes server context to 2FA requests
### Frontend Components
1. **Steam2FA Store** (`src/stores/steam2fa.ts`)
- Svelte store for managing 2FA state
- Automatic polling for pending requests
- API communication methods
2. **Steam2FANotification Component** (`src/components/Steam2FANotification.svelte`)
- Modal UI for 2FA confirmation
- Automatic display when requests are pending
- User interaction handling
3. **Type Definitions** (`src/models/steam2fa.ts`)
- TypeScript interfaces for 2FA data structures
## API Endpoints
### GET /v1/steam2fa/pending
Returns all pending 2FA requests.
**Response:**
```json
[
{
"id": "uuid-string",
"status": "pending",
"message": "Steam Guard prompt message",
"requestTime": "2024-01-01T12:00:00Z",
"serverId": "server-uuid"
}
]
```
### GET /v1/steam2fa/{id}
Returns a specific 2FA request by ID.
### POST /v1/steam2fa/{id}/complete
Marks a 2FA request as completed, allowing SteamCMD to proceed.
### POST /v1/steam2fa/{id}/cancel
Cancels a 2FA request, causing the SteamCMD operation to fail.
## Flow Diagram
```
SteamCMD Operation
InteractiveCommandExecutor monitors output
2FA prompt detected
Steam2FARequest created
Frontend polls and detects request
Modal appears for user
User confirms in Steam Mobile App
User clicks "I've Confirmed"
API call to complete request
SteamCMD operation continues
```
## Configuration
### Backend Configuration
The system uses existing configuration patterns. No additional environment variables are required.
### Frontend Configuration
The API base URL is automatically configured as `/v1` to match the backend prefix.
Polling interval is set to 5 seconds by default and can be modified in `steam2fa.ts`:
```typescript
const POLLING_INTERVAL = 5000; // milliseconds
```
## Security Considerations
1. **Authentication Required**: All 2FA endpoints require user authentication
2. **Permission-Based Access**: Uses existing `ServerView` and `ServerUpdate` permissions
3. **Request Cleanup**: Automatic cleanup of old requests (30 minutes) prevents memory leaks
4. **No Sensitive Data**: No Steam credentials are exposed through the 2FA system
## Error Handling
### Backend Error Handling
- Timeouts after 5 minutes if no user response
- Proper error propagation to calling services
- Comprehensive logging for debugging
### Frontend Error Handling
- Network error handling with user feedback
- Automatic retry mechanisms
- Graceful degradation when API is unavailable
## Usage Instructions
### For Developers
1. **Adding New 2FA Prompts**: Extend the `is2FAPrompt` function in `interactive_executor.go`
2. **Customizing Timeouts**: Modify the timeout duration in `handle2FAPrompt`
3. **UI Customization**: Modify the `Steam2FANotification.svelte` component
### For Users
1. When creating or updating a server, watch for the 2FA notification
2. Check your Steam Mobile App when prompted
3. Confirm the login request in the Steam app
4. Click "I've Confirmed" in the web interface
5. The server operation will continue automatically
## Monitoring and Debugging
### Backend Logs
The system logs important events:
- 2FA prompt detection
- Request creation and completion
- Timeout events
- Error conditions
Search for log entries containing:
- `2FA prompt detected`
- `Created 2FA request`
- `2FA completed successfully`
- `2FA completion failed`
### Frontend Debugging
The Steam2FA store provides debugging information:
- `$steam2fa.error` - Current error state
- `$steam2fa.isLoading` - Loading state
- `$steam2fa.lastChecked` - Last polling timestamp
## Performance Considerations
1. **Polling Frequency**: 5-second polling provides good responsiveness without excessive load
2. **Request Cleanup**: Automatic cleanup prevents memory accumulation
3. **Efficient UI Updates**: Reactive Svelte stores minimize unnecessary re-renders
## Limitations
1. **Single User Sessions**: Currently designed for single-user scenarios
2. **Steam Mobile App Required**: Users must have Steam Mobile App installed
3. **Manual Confirmation**: No automatic 2FA code input support
## Future Enhancements
1. **WebSocket Support**: Real-time communication instead of polling
2. **Multiple User Support**: Handle multiple simultaneous 2FA requests
3. **Enhanced Prompt Detection**: More sophisticated Steam output parsing
4. **Notification System**: Browser notifications for 2FA requests
## Testing
### Manual Testing
1. Create a new server to trigger SteamCMD
2. Ensure Steam account has 2FA enabled
3. Verify modal appears when 2FA is required
4. Test both "confirm" and "cancel" workflows
### Automated Testing
The system includes comprehensive error handling but manual testing is recommended for 2FA workflows due to the interactive nature.
## Troubleshooting
### Common Issues
1. **Modal doesn't appear**
- Check browser console for errors
- Verify API connectivity
- Ensure user has proper permissions
2. **SteamCMD hangs**
- Check if 2FA request was created (backend logs)
- Verify Steam Mobile App connectivity
- Check for timeout errors
3. **API errors**
- Verify user authentication
- Check server permissions
- Review backend error logs
### Debug Commands
```bash
# Check backend logs for 2FA events
grep -i "2fa" logs/app.log
# Monitor API requests
tail -f logs/app.log | grep "steam2fa"
```
## Version History
- **v1.0.0**: Initial implementation with polling-based frontend and REST API
- Added comprehensive error handling and logging
- Implemented automatic request cleanup
- Added responsive UI components
## Contributing
When contributing to the 2FA system:
1. Follow existing error handling patterns
2. Add comprehensive logging for new features
3. Update this documentation for any API changes
4. Test with actual Steam 2FA scenarios
5. Consider security implications of any changes

3
go.mod
View File

@@ -21,10 +21,12 @@ require (
require (
github.com/KyleBanks/depth v1.2.1 // indirect
github.com/andybalholm/brotli v1.1.0 // indirect
github.com/fasthttp/websocket v1.5.3 // indirect
github.com/go-openapi/jsonpointer v0.21.0 // indirect
github.com/go-openapi/jsonreference v0.21.0 // indirect
github.com/go-openapi/spec v0.21.0 // indirect
github.com/go-openapi/swag v0.23.0 // indirect
github.com/gofiber/websocket/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/intern v1.0.0 // indirect
@@ -35,6 +37,7 @@ require (
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect
github.com/swaggo/files/v2 v2.0.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.51.0 // indirect

6
go.sum
View File

@@ -4,6 +4,8 @@ github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fasthttp/websocket v1.5.3 h1:TPpQuLwJYfd4LJPXvHDYPMFWbLjsT91n3GpWtCQtdek=
github.com/fasthttp/websocket v1.5.3/go.mod h1:46gg/UBmTU1kUaTcwQXpUxtRwG2PvIZYeA8oL6vF3Fs=
github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ=
github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY=
github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ=
@@ -16,6 +18,8 @@ github.com/gofiber/fiber/v2 v2.52.8 h1:xl4jJQ0BV5EJTA2aWiKw/VddRpHrKeZLF0QPUxqn0
github.com/gofiber/fiber/v2 v2.52.8/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
github.com/gofiber/swagger v1.1.0 h1:ff3rg1fB+Rp5JN/N8jfxTiZtMKe/9tB9QDc79fPiJKQ=
github.com/gofiber/swagger v1.1.0/go.mod h1:pRZL0Np35sd+lTODTE5The0G+TMHfNY+oC4hM2/i5m8=
github.com/gofiber/websocket/v2 v2.2.1 h1:C9cjxvloojayOp9AovmpQrk8VqvVnT8Oao3+IUygH7w=
github.com/gofiber/websocket/v2 v2.2.1/go.mod h1:Ao/+nyNnX5u/hIFPuHl28a+NIkrqK7PRimyKaj4JxVU=
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@@ -53,6 +57,8 @@ github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk=
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/swaggo/files/v2 v2.0.0 h1:hmAt8Dkynw7Ssz46F6pn8ok6YmGZqHSVLZ+HQM7i0kw=

View File

@@ -31,6 +31,7 @@ func Init(di *dig.Container, app *fiber.App) {
StateHistory: serverIdGroup.Group("/state-history"),
Membership: groups.Group("/membership"),
System: groups.Group("/system"),
WebSocket: groups.Group("/ws"),
}
accessKeyMiddleware := middleware.NewAccessKeyMiddleware()

View File

@@ -58,7 +58,7 @@ func NewConfigController(as *service.ConfigService, routeGroups *common.RouteGro
// @Failure 404 {object} error_handler.ErrorResponse "Server or config file not found"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/server/{id}/config/{file} [put]
// @Router /server/{id}/config/{file} [put]
func (ac *ConfigController) UpdateConfig(c *fiber.Ctx) error {
restart := c.QueryBool("restart")
serverID := c.Params("id")
@@ -106,7 +106,7 @@ func (ac *ConfigController) UpdateConfig(c *fiber.Ctx) error {
// @Failure 404 {object} error_handler.ErrorResponse "Server or config file not found"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/server/{id}/config/{file} [get]
// @Router /server/{id}/config/{file} [get]
func (ac *ConfigController) GetConfig(c *fiber.Ctx) error {
Model, err := ac.service.GetConfig(c)
if err != nil {
@@ -130,7 +130,7 @@ func (ac *ConfigController) GetConfig(c *fiber.Ctx) error {
// @Failure 404 {object} error_handler.ErrorResponse "Server not found"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/server/{id}/config [get]
// @Router /server/{id}/config [get]
func (ac *ConfigController) GetConfigs(c *fiber.Ctx) error {
Model, err := ac.service.GetConfigs(c)
if err != nil {

View File

@@ -54,4 +54,9 @@ func InitializeControllers(c *dig.Container) {
if err != nil {
logging.Panic("unable to initialize membership controller")
}
err = c.Invoke(NewWebSocketController)
if err != nil {
logging.Panic("unable to initialize websocket controller")
}
}

View File

@@ -46,7 +46,7 @@ func NewLookupController(as *service.LookupService, routeGroups *common.RouteGro
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/lookup/tracks [get]
// @Router /lookup/tracks [get]
func (ac *LookupController) GetTracks(c *fiber.Ctx) error {
result, err := ac.service.GetTracks(c)
if err != nil {
@@ -66,7 +66,7 @@ func (ac *LookupController) GetTracks(c *fiber.Ctx) error {
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/lookup/car-models [get]
// @Router /lookup/car-models [get]
func (ac *LookupController) GetCarModels(c *fiber.Ctx) error {
result, err := ac.service.GetCarModels(c)
if err != nil {
@@ -86,7 +86,7 @@ func (ac *LookupController) GetCarModels(c *fiber.Ctx) error {
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/lookup/driver-categories [get]
// @Router /lookup/driver-categories [get]
func (ac *LookupController) GetDriverCategories(c *fiber.Ctx) error {
result, err := ac.service.GetDriverCategories(c)
if err != nil {
@@ -106,7 +106,7 @@ func (ac *LookupController) GetDriverCategories(c *fiber.Ctx) error {
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/lookup/cup-categories [get]
// @Router /lookup/cup-categories [get]
func (ac *LookupController) GetCupCategories(c *fiber.Ctx) error {
result, err := ac.service.GetCupCategories(c)
if err != nil {
@@ -126,7 +126,7 @@ func (ac *LookupController) GetCupCategories(c *fiber.Ctx) error {
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/lookup/session-types [get]
// @Router /lookup/session-types [get]
func (ac *LookupController) GetSessionTypes(c *fiber.Ctx) error {
result, err := ac.service.GetSessionTypes(c)
if err != nil {

View File

@@ -34,6 +34,7 @@ func NewMembershipController(service *service.MembershipService, auth *middlewar
}
routeGroups.Auth.Post("/login", mc.Login)
routeGroups.Auth.Post("/open-token", mc.auth.Authenticate, mc.GenerateOpenToken)
usersGroup := routeGroups.Membership
usersGroup.Use(mc.auth.Authenticate)
@@ -82,6 +83,26 @@ func (c *MembershipController) Login(ctx *fiber.Ctx) error {
return ctx.JSON(fiber.Map{"token": token})
}
// GenerateOpenToken generates an open token for a user.
// @Summary Generate an open token
// @Description Generate an open token for a user
// @Tags Authentication
// @Accept json
// @Produce json
// @Success 200 {object} object{token=string} "JWT token"
// @Failure 400 {object} error_handler.ErrorResponse "Invalid request body"
// @Failure 401 {object} error_handler.ErrorResponse "Invalid credentials"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Router /auth/open-token [post]
func (c *MembershipController) GenerateOpenToken(ctx *fiber.Ctx) error {
token, err := c.service.GenerateOpenToken(ctx.UserContext(), ctx.Locals("userID").(string))
if err != nil {
return c.errorHandler.HandleAuthError(ctx, err)
}
return ctx.JSON(fiber.Map{"token": token})
}
// CreateUser creates a new user.
// @Summary Create a new user
// @Description Create a new user account with specified role
@@ -139,6 +160,18 @@ func (mc *MembershipController) ListUsers(c *fiber.Ctx) error {
}
// GetUser gets a single user by ID.
// @Summary Get user by ID
// @Description Get detailed information about a specific user
// @Tags User Management
// @Accept json
// @Produce json
// @Param id path string true "User ID (UUID format)"
// @Success 200 {object} model.User "User details"
// @Failure 400 {object} error_handler.ErrorResponse "Invalid user ID format"
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
// @Failure 404 {object} error_handler.ErrorResponse "User not found"
// @Security BearerAuth
// @Router /membership/{id} [get]
func (mc *MembershipController) GetUser(c *fiber.Ctx) error {
id, err := uuid.Parse(c.Params("id"))
if err != nil {
@@ -154,6 +187,16 @@ func (mc *MembershipController) GetUser(c *fiber.Ctx) error {
}
// GetMe returns the currently authenticated user's details.
// @Summary Get current user details
// @Description Get details of the currently authenticated user
// @Tags Authentication
// @Accept json
// @Produce json
// @Success 200 {object} model.User "Current user details"
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
// @Failure 404 {object} error_handler.ErrorResponse "User not found"
// @Security BearerAuth
// @Router /auth/me [get]
func (mc *MembershipController) GetMe(c *fiber.Ctx) error {
userID, ok := c.Locals("userID").(string)
if !ok || userID == "" {
@@ -172,6 +215,19 @@ func (mc *MembershipController) GetMe(c *fiber.Ctx) error {
}
// DeleteUser deletes a user.
// @Summary Delete user
// @Description Delete a specific user by ID
// @Tags User Management
// @Accept json
// @Produce json
// @Param id path string true "User ID (UUID format)"
// @Success 204 "User successfully deleted"
// @Failure 400 {object} error_handler.ErrorResponse "Invalid user ID format"
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
// @Failure 403 {object} error_handler.ErrorResponse "Insufficient permissions"
// @Failure 404 {object} error_handler.ErrorResponse "User not found"
// @Security BearerAuth
// @Router /membership/{id} [delete]
func (mc *MembershipController) DeleteUser(c *fiber.Ctx) error {
id, err := uuid.Parse(c.Params("id"))
if err != nil {
@@ -187,6 +243,20 @@ func (mc *MembershipController) DeleteUser(c *fiber.Ctx) error {
}
// UpdateUser updates a user.
// @Summary Update user
// @Description Update user details by ID
// @Tags User Management
// @Accept json
// @Produce json
// @Param id path string true "User ID (UUID format)"
// @Param user body service.UpdateUserRequest true "Updated user details"
// @Success 200 {object} model.User "Updated user details"
// @Failure 400 {object} error_handler.ErrorResponse "Invalid request body or ID format"
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
// @Failure 403 {object} error_handler.ErrorResponse "Insufficient permissions"
// @Failure 404 {object} error_handler.ErrorResponse "User not found"
// @Security BearerAuth
// @Router /membership/{id} [put]
func (mc *MembershipController) UpdateUser(c *fiber.Ctx) error {
id, err := uuid.Parse(c.Params("id"))
if err != nil {
@@ -207,6 +277,17 @@ func (mc *MembershipController) UpdateUser(c *fiber.Ctx) error {
}
// GetRoles returns all available roles.
// @Summary Get all roles
// @Description Get a list of all available user roles
// @Tags User Management
// @Accept json
// @Produce json
// @Success 200 {array} model.Role "List of roles"
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
// @Failure 403 {object} error_handler.ErrorResponse "Insufficient permissions"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /membership/roles [get]
func (mc *MembershipController) GetRoles(c *fiber.Ctx) error {
roles, err := mc.service.GetAllRoles(c.UserContext())
if err != nil {

View File

@@ -29,7 +29,6 @@ func NewServerController(ss *service.ServerService, routeGroups *common.RouteGro
serverRoutes.Get("/", auth.HasPermission(model.ServerView), ac.GetAll)
serverRoutes.Get("/:id", auth.HasPermission(model.ServerView), ac.GetById)
serverRoutes.Post("/", auth.HasPermission(model.ServerCreate), ac.CreateServer)
serverRoutes.Put("/:id", auth.HasPermission(model.ServerUpdate), ac.UpdateServer)
serverRoutes.Delete("/:id", auth.HasPermission(model.ServerDelete), ac.DeleteServer)
apiServerRoutes := routeGroups.Api.Group("/server")
@@ -50,7 +49,7 @@ func NewServerController(ss *service.ServerService, routeGroups *common.RouteGro
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /server [get]
// @Router /api/server [get]
func (ac *ServerController) GetAllApi(c *fiber.Ctx) error {
var filter model.ServerFilter
if err := common.ParseQueryFilter(c, &filter); err != nil {
@@ -79,7 +78,7 @@ func (ac *ServerController) GetAllApi(c *fiber.Ctx) error {
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/server [get]
// @Router /server [get]
func (ac *ServerController) GetAll(c *fiber.Ctx) error {
var filter model.ServerFilter
if err := common.ParseQueryFilter(c, &filter); err != nil {
@@ -105,7 +104,7 @@ func (ac *ServerController) GetAll(c *fiber.Ctx) error {
// @Failure 404 {object} error_handler.ErrorResponse "Server not found"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/server/{id} [get]
// @Router /server/{id} [get]
func (ac *ServerController) GetById(c *fiber.Ctx) error {
serverIDStr := c.Params("id")
serverID, err := uuid.Parse(serverIDStr)
@@ -133,55 +132,40 @@ func (ac *ServerController) GetById(c *fiber.Ctx) error {
// @Failure 403 {object} error_handler.ErrorResponse "Insufficient permissions"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/server [post]
// @Router /server [post]
func (ac *ServerController) CreateServer(c *fiber.Ctx) error {
server := new(model.Server)
if err := c.BodyParser(server); err != nil {
return ac.errorHandler.HandleParsingError(c, err)
}
ac.service.GenerateServerPath(server)
if err := ac.service.CreateServer(c, server); err != nil {
server.GenerateUUID()
// Use async server creation to avoid blocking other requests
if err := ac.service.CreateServerAsync(c, server); err != nil {
return ac.errorHandler.HandleServiceError(c, err)
}
// Return immediately with server details
// The actual creation will happen in the background with WebSocket updates
return c.JSON(server)
}
// UpdateServer updates an existing server
// @Summary Update an ACC server
// @Description Update configuration for an existing ACC server
// DeleteServer deletes an existing server
// @Summary Delete an ACC server
// @Description Delete an existing ACC server
// @Tags Server
// @Accept json
// @Produce json
// @Param id path string true "Server ID (UUID format)"
// @Param server body model.Server true "Updated server configuration"
// @Success 200 {object} object "Updated server details"
// @Success 200 {object} object "Deleted server details"
// @Failure 400 {object} error_handler.ErrorResponse "Invalid server data or ID"
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
// @Failure 403 {object} error_handler.ErrorResponse "Insufficient permissions"
// @Failure 404 {object} error_handler.ErrorResponse "Server not found"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/server/{id} [put]
func (ac *ServerController) UpdateServer(c *fiber.Ctx) error {
serverIDStr := c.Params("id")
serverID, err := uuid.Parse(serverIDStr)
if err != nil {
return ac.errorHandler.HandleUUIDError(c, "server ID")
}
server := new(model.Server)
if err := c.BodyParser(server); err != nil {
return ac.errorHandler.HandleParsingError(c, err)
}
server.ID = serverID
if err := ac.service.UpdateServer(c, server); err != nil {
return ac.errorHandler.HandleServiceError(c, err)
}
return c.JSON(server)
}
// DeleteServer deletes a server
// @Router /server/{id} [delete]
func (ac *ServerController) DeleteServer(c *fiber.Ctx) error {
serverIDStr := c.Params("id")
serverID, err := uuid.Parse(serverIDStr)

View File

@@ -51,7 +51,7 @@ func NewServiceControlController(as *service.ServiceControlService, routeGroups
// @Failure 404 {object} error_handler.ErrorResponse "Service not found"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/service-control/{service} [get]
// @Router /server/{id}/service/{service} [get]
func (ac *ServiceControlController) getStatus(c *fiber.Ctx) error {
id := c.Params("id")
c.Locals("serverId", id)
@@ -78,7 +78,7 @@ func (ac *ServiceControlController) getStatus(c *fiber.Ctx) error {
// @Failure 409 {object} error_handler.ErrorResponse "Service already running"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/service-control/start [post]
// @Router /server/{id}/service/start [post]
func (ac *ServiceControlController) startServer(c *fiber.Ctx) error {
id := c.Params("id")
c.Locals("serverId", id)
@@ -105,7 +105,7 @@ func (ac *ServiceControlController) startServer(c *fiber.Ctx) error {
// @Failure 409 {object} error_handler.ErrorResponse "Service already stopped"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/service-control/stop [post]
// @Router /server/{id}/service/stop [post]
func (ac *ServiceControlController) stopServer(c *fiber.Ctx) error {
id := c.Params("id")
c.Locals("serverId", id)
@@ -131,7 +131,7 @@ func (ac *ServiceControlController) stopServer(c *fiber.Ctx) error {
// @Failure 404 {object} error_handler.ErrorResponse "Service not found"
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
// @Security BearerAuth
// @Router /v1/service-control/restart [post]
// @Router /server/{id}/service/restart [post]
func (ac *ServiceControlController) restartServer(c *fiber.Ctx) error {
id := c.Params("id")
c.Locals("serverId", id)

View File

@@ -42,7 +42,7 @@ func NewStateHistoryController(as *service.StateHistoryService, routeGroups *com
// @Description Return StateHistorys
// @Tags StateHistory
// @Success 200 {array} string
// @Router /v1/state-history [get]
// @Router /state-history [get]
func (ac *StateHistoryController) GetAll(c *fiber.Ctx) error {
var filter model.StateHistoryFilter
if err := common.ParseQueryFilter(c, &filter); err != nil {
@@ -63,7 +63,7 @@ func (ac *StateHistoryController) GetAll(c *fiber.Ctx) error {
// @Description Return StateHistorys
// @Tags StateHistory
// @Success 200 {array} string
// @Router /v1/state-history/statistics [get]
// @Router /state-history/statistics [get]
func (ac *StateHistoryController) GetStatistics(c *fiber.Ctx) error {
var filter model.StateHistoryFilter
if err := common.ParseQueryFilter(c, &filter); err != nil {

View File

@@ -31,9 +31,9 @@ func NewSystemController(routeGroups *common.RouteGroups) *SystemController {
//
// @Summary Return service control status
// @Description Return service control status
// @Tags service-control
// @Tags system
// @Success 200 {array} string
// @Router /v1/service-control [get]
// @Router /system/health [get]
func (ac *SystemController) getFirst(c *fiber.Ctx) error {
return c.SendString(configs.Version)
}

View File

@@ -0,0 +1,138 @@
package controller
import (
"acc-server-manager/local/middleware"
"acc-server-manager/local/service"
"acc-server-manager/local/utl/common"
"acc-server-manager/local/utl/jwt"
"acc-server-manager/local/utl/logging"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
"github.com/google/uuid"
)
type WebSocketController struct {
webSocketService *service.WebSocketService
jwtHandler *jwt.OpenJWTHandler
}
func NewWebSocketController(
wsService *service.WebSocketService,
jwtHandler *jwt.OpenJWTHandler,
routeGroups *common.RouteGroups,
auth *middleware.AuthMiddleware,
) *WebSocketController {
wsc := &WebSocketController{
webSocketService: wsService,
jwtHandler: jwtHandler,
}
wsRoutes := routeGroups.WebSocket
wsRoutes.Use("/", wsc.upgradeWebSocket)
wsRoutes.Get("/", websocket.New(wsc.handleWebSocket))
return wsc
}
func (wsc *WebSocketController) upgradeWebSocket(c *fiber.Ctx) error {
if websocket.IsWebSocketUpgrade(c) {
token := c.Query("token")
if token == "" {
token = c.Get("Authorization")
if token != "" && len(token) > 7 && token[:7] == "Bearer " {
token = token[7:]
}
}
if token == "" {
return fiber.NewError(fiber.StatusUnauthorized, "Missing authentication token")
}
claims, err := wsc.jwtHandler.ValidateToken(token)
if err != nil {
return fiber.NewError(fiber.StatusUnauthorized, "Invalid authentication token")
}
userID, err := uuid.Parse(claims.UserID)
if err != nil {
return fiber.NewError(fiber.StatusUnauthorized, "Invalid user ID in token")
}
c.Locals("userID", userID)
c.Locals("username", claims.UserID)
return c.Next()
}
return fiber.NewError(fiber.StatusUpgradeRequired, "WebSocket upgrade required")
}
func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) {
connID := uuid.New().String()
userID, ok := c.Locals("userID").(uuid.UUID)
if !ok {
logging.Error("Failed to get user ID from WebSocket connection")
c.Close()
return
}
username, _ := c.Locals("username").(string)
logging.Info("WebSocket connection established for user: %s (ID: %s)", username, userID.String())
wsc.webSocketService.AddConnection(connID, c, &userID)
defer func() {
wsc.webSocketService.RemoveConnection(connID)
logging.Info("WebSocket connection closed for user: %s", username)
}()
for {
messageType, message, err := c.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
logging.Error("WebSocket error for user %s: %v", username, err)
}
break
}
switch messageType {
case websocket.TextMessage:
wsc.handleTextMessage(connID, userID, message)
case websocket.BinaryMessage:
logging.Debug("Received binary message from user %s (not supported)", username)
case websocket.PingMessage:
if err := c.WriteMessage(websocket.PongMessage, nil); err != nil {
logging.Error("Failed to send pong to user %s: %v", username, err)
break
}
}
}
}
func (wsc *WebSocketController) handleTextMessage(connID string, userID uuid.UUID, message []byte) {
logging.Debug("Received WebSocket message from user %s: %s", userID.String(), string(message))
messageStr := string(message)
if len(messageStr) > 10 && messageStr[:9] == "server_id" {
if serverIDStr := messageStr[10:]; len(serverIDStr) > 0 {
if serverID, err := uuid.Parse(serverIDStr); err == nil {
wsc.webSocketService.SetServerID(connID, serverID)
logging.Info("Associated WebSocket connection %s with server %s", connID, serverID.String())
}
}
}
}
func (wsc *WebSocketController) GetWebSocketUpgrade() fiber.Handler {
return wsc.upgradeWebSocket
}
func (wsc *WebSocketController) GetWebSocketHandler() func(*websocket.Conn) {
return wsc.handleWebSocket
}
func (wsc *WebSocketController) BroadcastServerCreationProgress(serverID uuid.UUID, step string, status string, message string) {
logging.Info("Broadcasting server creation progress: %s - %s: %s", serverID.String(), step, status)
}

View File

@@ -10,12 +10,10 @@ import (
"github.com/google/uuid"
)
// AccessKeyMiddleware provides authentication and permission middleware.
type AccessKeyMiddleware struct {
userInfo CachedUserInfo
}
// NewAccessKeyMiddleware creates a new AccessKeyMiddleware.
func NewAccessKeyMiddleware() *AccessKeyMiddleware {
auth := &AccessKeyMiddleware{
userInfo: CachedUserInfo{UserID: uuid.New().String(), Username: "access_key", RoleName: "Admin", Permissions: map[string]bool{
@@ -25,9 +23,7 @@ func NewAccessKeyMiddleware() *AccessKeyMiddleware {
return auth
}
// Authenticate is a middleware for JWT authentication with enhanced security.
func (m *AccessKeyMiddleware) Authenticate(ctx *fiber.Ctx) error {
// Log authentication attempt
ip := ctx.IP()
userAgent := ctx.Get("User-Agent")

View File

@@ -16,7 +16,6 @@ import (
"github.com/google/uuid"
)
// CachedUserInfo holds cached user authentication and permission data
type CachedUserInfo struct {
UserID string
Username string
@@ -25,34 +24,49 @@ type CachedUserInfo struct {
CachedAt time.Time
}
// AuthMiddleware provides authentication and permission middleware.
type AuthMiddleware struct {
membershipService *service.MembershipService
cache *cache.InMemoryCache
securityMW *security.SecurityMiddleware
jwtHandler *jwt.JWTHandler
openJWTHandler *jwt.OpenJWTHandler
}
// NewAuthMiddleware creates a new AuthMiddleware.
func NewAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache) *AuthMiddleware {
func NewAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache, jwtHandler *jwt.JWTHandler, openJWTHandler *jwt.OpenJWTHandler) *AuthMiddleware {
auth := &AuthMiddleware{
membershipService: ms,
cache: cache,
securityMW: security.NewSecurityMiddleware(),
jwtHandler: jwtHandler,
openJWTHandler: openJWTHandler,
}
// Set up bidirectional relationship for cache invalidation
ms.SetCacheInvalidator(auth)
return auth
}
// Authenticate is a middleware for JWT authentication with enhanced security.
func (m *AuthMiddleware) AuthenticateOpen(ctx *fiber.Ctx) error {
return m.AuthenticateWithHandler(m.openJWTHandler.JWTHandler, true, ctx)
}
func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
// Log authentication attempt
return m.AuthenticateWithHandler(m.jwtHandler, false, ctx)
}
func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isOpenToken bool, ctx *fiber.Ctx) error {
ip := ctx.IP()
userAgent := ctx.Get("User-Agent")
authHeader := ctx.Get("Authorization")
if jwtHandler.IsOpenToken && !isOpenToken {
logging.Error("Authentication failed: attempting to authenticate with open token")
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"error": "Wrong token type used",
})
}
if authHeader == "" {
logging.Error("Authentication failed: missing Authorization header from IP %s", ip)
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
@@ -68,7 +82,6 @@ func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
})
}
// Validate token length to prevent potential attacks
token := parts[1]
if len(token) < 10 || len(token) > 2048 {
logging.Error("Authentication failed: invalid token length from IP %s", ip)
@@ -77,7 +90,7 @@ func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
})
}
claims, err := jwt.ValidateToken(token)
claims, err := jwtHandler.ValidateToken(token)
if err != nil {
logging.Error("Authentication failed: invalid token from IP %s, User-Agent: %s, Error: %v", ip, userAgent, err)
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
@@ -85,7 +98,13 @@ func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
})
}
// Additional security: validate user ID format
if !jwtHandler.IsOpenToken && claims.IsOpenToken {
logging.Error("Authentication failed: attempting to authenticate with open token")
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"error": "Wrong token type used",
})
}
if claims.UserID == "" || len(claims.UserID) < 10 {
logging.Error("Authentication failed: invalid user ID in token from IP %s", ip)
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
@@ -99,7 +118,6 @@ func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
ctx.Locals("userInfo", userInfo)
ctx.Locals("authTime", time.Now())
} else {
// Preload and cache user info to avoid database queries on permission checks
userInfo, err := m.getCachedUserInfo(ctx.UserContext(), claims.UserID)
if err != nil {
logging.Error("Authentication failed: unable to load user info for %s from IP %s: %v", claims.UserID, ip, err)
@@ -117,7 +135,6 @@ func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
return ctx.Next()
}
// HasPermission is a middleware for checking user permissions with enhanced logging.
func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler {
return func(ctx *fiber.Ctx) error {
userID, ok := ctx.Locals("userID").(string)
@@ -132,7 +149,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
return ctx.Next()
}
// Validate permission parameter
if requiredPermission == "" {
logging.Error("Permission check failed: empty permission requirement")
return ctx.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
@@ -140,7 +156,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
})
}
// Use cached user info from authentication step - no database queries needed
userInfo, ok := ctx.Locals("userInfo").(*CachedUserInfo)
if !ok {
logging.Error("Permission check failed: no cached user info in context from IP %s", ctx.IP())
@@ -149,7 +164,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
})
}
// Check if user has permission using cached data
has := m.hasPermissionFromCache(userInfo, requiredPermission)
if !has {
@@ -164,16 +178,13 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
}
}
// AuthRateLimit applies rate limiting specifically for authentication endpoints
func (m *AuthMiddleware) AuthRateLimit() fiber.Handler {
return m.securityMW.AuthRateLimit()
}
// RequireHTTPS redirects HTTP requests to HTTPS in production
func (m *AuthMiddleware) RequireHTTPS() fiber.Handler {
return func(ctx *fiber.Ctx) error {
if ctx.Protocol() != "https" && ctx.Get("X-Forwarded-Proto") != "https" {
// Allow HTTP in development/testing
if ctx.Hostname() != "localhost" && ctx.Hostname() != "127.0.0.1" {
httpsURL := "https://" + ctx.Hostname() + ctx.OriginalURL()
return ctx.Redirect(httpsURL, fiber.StatusMovedPermanently)
@@ -183,11 +194,9 @@ func (m *AuthMiddleware) RequireHTTPS() fiber.Handler {
}
}
// getCachedUserInfo retrieves and caches complete user information including permissions
func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) (*CachedUserInfo, error) {
cacheKey := fmt.Sprintf("userinfo:%s", userID)
// Try cache first
if cached, found := m.cache.Get(cacheKey); found {
if userInfo, ok := cached.(*CachedUserInfo); ok {
logging.DebugWithContext("AUTH_CACHE", "User info for %s found in cache", userID)
@@ -195,13 +204,11 @@ func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) (
}
}
// Cache miss - load from database
user, err := m.membershipService.GetUserWithPermissions(ctx, userID)
if err != nil {
return nil, err
}
// Build permission map for fast lookups
permissions := make(map[string]bool)
for _, p := range user.Role.Permissions {
permissions[p.Name] = true
@@ -215,34 +222,26 @@ func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) (
CachedAt: time.Now(),
}
// Cache for 15 minutes
m.cache.Set(cacheKey, userInfo, 15*time.Minute)
logging.DebugWithContext("AUTH_CACHE", "User info for %s cached with %d permissions", userID, len(permissions))
return userInfo, nil
}
// hasPermissionFromCache checks permissions using cached user info (no database queries)
func (m *AuthMiddleware) hasPermissionFromCache(userInfo *CachedUserInfo, permission string) bool {
// Super Admin and Admin have all permissions
if userInfo.RoleName == "Super Admin" || userInfo.RoleName == "Admin" {
return true
}
// Check specific permission in cached map
return userInfo.Permissions[permission]
}
// InvalidateUserPermissions removes cached user info for a user
func (m *AuthMiddleware) InvalidateUserPermissions(userID string) {
cacheKey := fmt.Sprintf("userinfo:%s", userID)
m.cache.Delete(cacheKey)
logging.InfoWithContext("AUTH_CACHE", "User info cache invalidated for user %s", userID)
}
// InvalidateAllUserPermissions clears all cached user info (useful for role/permission changes)
func (m *AuthMiddleware) InvalidateAllUserPermissions() {
// This would need to be implemented based on your cache interface
// For now, just log that invalidation was requested
logging.InfoWithContext("AUTH_CACHE", "All user info caches invalidation requested")
}

View File

@@ -7,25 +7,20 @@ import (
"github.com/gofiber/fiber/v2"
)
// RequestLoggingMiddleware logs HTTP requests and responses
type RequestLoggingMiddleware struct {
infoLogger *logging.InfoLogger
}
// NewRequestLoggingMiddleware creates a new request logging middleware
func NewRequestLoggingMiddleware() *RequestLoggingMiddleware {
return &RequestLoggingMiddleware{
infoLogger: logging.GetInfoLogger(),
}
}
// Handler returns the middleware handler function
func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler {
return func(c *fiber.Ctx) error {
// Record start time
start := time.Now()
// Log incoming request
userAgent := c.Get("User-Agent")
if userAgent == "" {
userAgent = "Unknown"
@@ -33,17 +28,13 @@ func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler {
rlm.infoLogger.LogRequest(c.Method(), c.OriginalURL(), userAgent)
// Continue to next handler
err := c.Next()
// Calculate duration
duration := time.Since(start)
// Log response
statusCode := c.Response().StatusCode()
rlm.infoLogger.LogResponse(c.Method(), c.OriginalURL(), statusCode, duration.String())
// Log error if present
if err != nil {
logging.ErrorWithContext("REQUEST_MIDDLEWARE", "Request failed: %v", err)
}
@@ -52,10 +43,8 @@ func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler {
}
}
// Global request logging middleware instance
var globalRequestLoggingMiddleware *RequestLoggingMiddleware
// GetRequestLoggingMiddleware returns the global request logging middleware
func GetRequestLoggingMiddleware() *RequestLoggingMiddleware {
if globalRequestLoggingMiddleware == nil {
globalRequestLoggingMiddleware = NewRequestLoggingMiddleware()
@@ -63,7 +52,6 @@ func GetRequestLoggingMiddleware() *RequestLoggingMiddleware {
return globalRequestLoggingMiddleware
}
// Handler returns the global request logging middleware handler
func Handler() fiber.Handler {
return GetRequestLoggingMiddleware().Handler()
}

View File

@@ -1,6 +1,7 @@
package security
import (
"acc-server-manager/local/utl/graceful"
"context"
"fmt"
"strings"
@@ -10,88 +11,76 @@ import (
"github.com/gofiber/fiber/v2"
)
// RateLimiter stores rate limiting information
type RateLimiter struct {
requests map[string][]time.Time
mutex sync.RWMutex
}
// NewRateLimiter creates a new rate limiter
func NewRateLimiter() *RateLimiter {
rl := &RateLimiter{
requests: make(map[string][]time.Time),
}
// Clean up old entries every 5 minutes
go rl.cleanup()
shutdownManager := graceful.GetManager()
shutdownManager.RunGoroutine(func(ctx context.Context) {
rl.cleanupWithContext(ctx)
})
return rl
}
// cleanup removes old entries from the rate limiter
func (rl *RateLimiter) cleanup() {
func (rl *RateLimiter) cleanupWithContext(ctx context.Context) {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
rl.mutex.Lock()
now := time.Now()
for key, times := range rl.requests {
// Remove entries older than 1 hour
filtered := make([]time.Time, 0, len(times))
for _, t := range times {
if now.Sub(t) < time.Hour {
filtered = append(filtered, t)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
rl.mutex.Lock()
now := time.Now()
for key, times := range rl.requests {
filtered := make([]time.Time, 0, len(times))
for _, t := range times {
if now.Sub(t) < time.Hour {
filtered = append(filtered, t)
}
}
if len(filtered) == 0 {
delete(rl.requests, key)
} else {
rl.requests[key] = filtered
}
}
if len(filtered) == 0 {
delete(rl.requests, key)
} else {
rl.requests[key] = filtered
}
rl.mutex.Unlock()
}
rl.mutex.Unlock()
}
}
// SecurityMiddleware provides comprehensive security middleware
type SecurityMiddleware struct {
rateLimiter *RateLimiter
}
// NewSecurityMiddleware creates a new security middleware
func NewSecurityMiddleware() *SecurityMiddleware {
return &SecurityMiddleware{
rateLimiter: NewRateLimiter(),
}
}
// SecurityHeaders adds security headers to responses
func (sm *SecurityMiddleware) SecurityHeaders() fiber.Handler {
return func(c *fiber.Ctx) error {
// Prevent MIME type sniffing
c.Set("X-Content-Type-Options", "nosniff")
// Prevent clickjacking
c.Set("X-Frame-Options", "DENY")
// Enable XSS protection
c.Set("X-XSS-Protection", "1; mode=block")
// Prevent referrer leakage
c.Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Content Security Policy
c.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none'")
// Permissions Policy
c.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), interest-cohort=()")
return c.Next()
}
}
// RateLimit implements rate limiting for API endpoints
func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration) fiber.Handler {
return func(c *fiber.Ctx) error {
ip := c.IP()
@@ -103,7 +92,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
now := time.Now()
requests := sm.rateLimiter.requests[key]
// Remove requests older than duration
filtered := make([]time.Time, 0, len(requests))
for _, t := range requests {
if now.Sub(t) < duration {
@@ -111,7 +99,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
}
}
// Check if limit is exceeded
if len(filtered) >= maxRequests {
return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
"error": "Rate limit exceeded",
@@ -119,7 +106,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
})
}
// Add current request
filtered = append(filtered, now)
sm.rateLimiter.requests[key] = filtered
@@ -127,7 +113,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
}
}
// AuthRateLimit implements stricter rate limiting for authentication endpoints
func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
return func(c *fiber.Ctx) error {
ip := c.IP()
@@ -140,7 +125,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
now := time.Now()
requests := sm.rateLimiter.requests[key]
// Remove requests older than 15 minutes
filtered := make([]time.Time, 0, len(requests))
for _, t := range requests {
if now.Sub(t) < 15*time.Minute {
@@ -148,7 +132,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
}
}
// Check if limit is exceeded (5 requests per 15 minutes for auth)
if len(filtered) >= 5 {
return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
"error": "Too many authentication attempts",
@@ -156,7 +139,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
})
}
// Add current request
filtered = append(filtered, now)
sm.rateLimiter.requests[key] = filtered
@@ -164,20 +146,16 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
}
}
// InputSanitization sanitizes user input to prevent XSS and injection attacks
func (sm *SecurityMiddleware) InputSanitization() fiber.Handler {
return func(c *fiber.Ctx) error {
// Sanitize query parameters
c.Request().URI().QueryArgs().VisitAll(func(key, value []byte) {
sanitized := sanitizeInput(string(value))
c.Request().URI().QueryArgs().Set(string(key), sanitized)
})
// Store original body for processing
if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" {
body := c.Body()
if len(body) > 0 {
// Basic sanitization - remove potentially dangerous patterns
sanitized := sanitizeInput(string(body))
c.Request().SetBodyString(sanitized)
}
@@ -187,15 +165,14 @@ func (sm *SecurityMiddleware) InputSanitization() fiber.Handler {
}
}
// sanitizeInput removes potentially dangerous patterns from input
func sanitizeInput(input string) string {
// Remove common XSS patterns
dangerous := []string{
"<script",
"</script>",
"javascript:",
"vbscript:",
"data:text/html",
"data:application",
"onload=",
"onerror=",
"onclick=",
@@ -204,28 +181,48 @@ func sanitizeInput(input string) string {
"onblur=",
"onchange=",
"onsubmit=",
"onkeydown=",
"onkeyup=",
"<iframe",
"<object",
"<embed",
"<link",
"<meta",
"<style",
"<form",
"<input",
"<button",
"<svg",
"<math",
"expression(",
"@import",
"url(",
"\\x",
"\\u",
"&#x",
"&#",
}
result := strings.ToLower(input)
result := input
lowerInput := strings.ToLower(input)
for _, pattern := range dangerous {
result = strings.ReplaceAll(result, pattern, "")
if strings.Contains(lowerInput, pattern) {
return ""
}
}
// If the sanitized version is very different, it might be malicious
if len(result) < len(input)/2 {
if strings.Contains(result, "\x00") {
return ""
}
return input
if len(strings.TrimSpace(result)) == 0 && len(input) > 0 {
return ""
}
return result
}
// ValidateContentType ensures only expected content types are accepted
func (sm *SecurityMiddleware) ValidateContentType(allowedTypes ...string) fiber.Handler {
return func(c *fiber.Ctx) error {
if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" {
@@ -236,7 +233,6 @@ func (sm *SecurityMiddleware) ValidateContentType(allowedTypes ...string) fiber.
})
}
// Check if content type is allowed
allowed := false
for _, allowedType := range allowedTypes {
if strings.Contains(contentType, allowedType) {
@@ -256,7 +252,6 @@ func (sm *SecurityMiddleware) ValidateContentType(allowedTypes ...string) fiber.
}
}
// ValidateUserAgent blocks requests with suspicious or missing user agents
func (sm *SecurityMiddleware) ValidateUserAgent() fiber.Handler {
suspiciousAgents := []string{
"sqlmap",
@@ -267,21 +262,19 @@ func (sm *SecurityMiddleware) ValidateUserAgent() fiber.Handler {
"dirb",
"dirbuster",
"wpscan",
"curl/7.0", // Very old curl versions
"wget/1.0", // Very old wget versions
"curl/7.0",
"wget/1.0",
}
return func(c *fiber.Ctx) error {
userAgent := strings.ToLower(c.Get("User-Agent"))
// Block empty user agents
if userAgent == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"error": "User-Agent header is required",
})
}
// Block suspicious user agents
for _, suspicious := range suspiciousAgents {
if strings.Contains(userAgent, suspicious) {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
@@ -294,7 +287,6 @@ func (sm *SecurityMiddleware) ValidateUserAgent() fiber.Handler {
}
}
// RequestSizeLimit limits the size of incoming requests
func (sm *SecurityMiddleware) RequestSizeLimit(maxSize int) fiber.Handler {
return func(c *fiber.Ctx) error {
if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" {
@@ -311,19 +303,15 @@ func (sm *SecurityMiddleware) RequestSizeLimit(maxSize int) fiber.Handler {
}
}
// LogSecurityEvents logs security-related events
func (sm *SecurityMiddleware) LogSecurityEvents() fiber.Handler {
return func(c *fiber.Ctx) error {
start := time.Now()
// Process request
err := c.Next()
// Log suspicious activity
status := c.Response().StatusCode()
if status == 401 || status == 403 || status == 429 {
duration := time.Since(start)
// In a real implementation, you would send this to your logging system
fmt.Printf("[SECURITY] %s %s %s %d %v %s\n",
time.Now().Format(time.RFC3339),
c.IP(),
@@ -338,7 +326,6 @@ func (sm *SecurityMiddleware) LogSecurityEvents() fiber.Handler {
}
}
// TimeoutMiddleware adds request timeout
func (sm *SecurityMiddleware) TimeoutMiddleware(timeout time.Duration) fiber.Handler {
return func(c *fiber.Ctx) error {
ctx, cancel := context.WithTimeout(c.UserContext(), timeout)
@@ -349,3 +336,24 @@ func (sm *SecurityMiddleware) TimeoutMiddleware(timeout time.Duration) fiber.Han
return c.Next()
}
}
func (sm *SecurityMiddleware) RequestContextTimeout(timeout time.Duration) fiber.Handler {
return func(c *fiber.Ctx) error {
ctx, cancel := context.WithTimeout(c.UserContext(), timeout)
defer cancel()
done := make(chan error, 1)
go func() {
done <- c.Next()
}()
select {
case err := <-done:
return err
case <-ctx.Done():
return c.Status(fiber.StatusRequestTimeout).JSON(fiber.Map{
"error": "Request timeout",
})
}
}
}

View File

@@ -9,21 +9,17 @@ import (
"gorm.io/gorm"
)
// Migration001UpgradePasswordSecurity migrates existing user passwords from encrypted to hashed format
type Migration001UpgradePasswordSecurity struct {
DB *gorm.DB
}
// NewMigration001UpgradePasswordSecurity creates a new password security migration
func NewMigration001UpgradePasswordSecurity(db *gorm.DB) *Migration001UpgradePasswordSecurity {
return &Migration001UpgradePasswordSecurity{DB: db}
}
// Up executes the migration
func (m *Migration001UpgradePasswordSecurity) Up() error {
logging.Info("Starting password security upgrade migration...")
// Check if migration has already been applied
var migrationRecord MigrationRecord
err := m.DB.Where("migration_name = ?", "001_upgrade_password_security").First(&migrationRecord).Error
if err == nil {
@@ -31,12 +27,10 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
return nil
}
// Create migration tracking table if it doesn't exist
if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil {
return fmt.Errorf("failed to create migration tracking table: %v", err)
}
// Start transaction
tx := m.DB.Begin()
if tx.Error != nil {
return fmt.Errorf("failed to start transaction: %v", tx.Error)
@@ -47,16 +41,13 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
}
}()
// Add a backup column for old passwords (temporary)
if err := tx.Exec("ALTER TABLE users ADD COLUMN password_backup TEXT").Error; err != nil {
// Column might already exist, ignore if it's a duplicate column error
if !isDuplicateColumnError(err) {
tx.Rollback()
return fmt.Errorf("failed to add backup column: %v", err)
}
}
// Get all users with encrypted passwords
var users []UserForMigration
if err := tx.Find(&users).Error; err != nil {
tx.Rollback()
@@ -72,19 +63,15 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
if err := m.migrateUserPassword(tx, &user); err != nil {
logging.Error("Failed to migrate user %s (ID: %s): %v", user.Username, user.ID, err)
failedCount++
// Continue with other users rather than failing completely
continue
}
migratedCount++
}
// Remove backup column after successful migration
if err := tx.Exec("ALTER TABLE users DROP COLUMN password_backup").Error; err != nil {
logging.Error("Failed to remove backup column (non-critical): %v", err)
// Don't fail the migration for this
}
// Record successful migration
migrationRecord = MigrationRecord{
MigrationName: "001_upgrade_password_security",
AppliedAt: "datetime('now')",
@@ -97,7 +84,6 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
return fmt.Errorf("failed to record migration: %v", err)
}
// Commit transaction
if err := tx.Commit().Error; err != nil {
return fmt.Errorf("failed to commit migration: %v", err)
}
@@ -111,32 +97,24 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
return nil
}
// migrateUserPassword migrates a single user's password
func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, user *UserForMigration) error {
// Skip if password is already hashed (bcrypt hashes start with $2a$, $2b$, or $2y$)
if isAlreadyHashed(user.Password) {
logging.Debug("User %s already has hashed password, skipping", user.Username)
return nil
}
// Backup original password
if err := tx.Model(user).Update("password_backup", user.Password).Error; err != nil {
return fmt.Errorf("failed to backup password: %v", err)
}
// Try to decrypt the old password
var plainPassword string
// First, try to decrypt using the old encryption method
decrypted, err := decryptOldPassword(user.Password)
if err != nil {
// If decryption fails, the password might already be plain text or corrupted
logging.Error("Failed to decrypt password for user %s, treating as plain text: %v", user.Username, err)
// Use original password as-is (might be plain text from development)
plainPassword = user.Password
// Validate it's not obviously encrypted data
if len(plainPassword) > 100 || containsBinaryData(plainPassword) {
return fmt.Errorf("password appears to be corrupted encrypted data")
}
@@ -144,7 +122,6 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u
plainPassword = decrypted
}
// Validate plain password
if plainPassword == "" {
return errors.New("decrypted password is empty")
}
@@ -153,13 +130,11 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u
return errors.New("password too short after decryption")
}
// Hash the plain password using bcrypt
hashedPassword, err := password.HashPassword(plainPassword)
if err != nil {
return fmt.Errorf("failed to hash password: %v", err)
}
// Update with hashed password
if err := tx.Model(user).Update("password", hashedPassword).Error; err != nil {
return fmt.Errorf("failed to update password: %v", err)
}
@@ -168,19 +143,16 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u
return nil
}
// UserForMigration represents a user record for migration purposes
type UserForMigration struct {
ID string `gorm:"column:id"`
Username string `gorm:"column:username"`
Password string `gorm:"column:password"`
}
// TableName specifies the table name for GORM
func (UserForMigration) TableName() string {
return "users"
}
// MigrationRecord tracks applied migrations
type MigrationRecord struct {
ID uint `gorm:"primaryKey"`
MigrationName string `gorm:"unique;not null"`
@@ -189,49 +161,38 @@ type MigrationRecord struct {
Notes string
}
// TableName specifies the table name for GORM
func (MigrationRecord) TableName() string {
return "migration_records"
}
// isAlreadyHashed checks if a password is already bcrypt hashed
func isAlreadyHashed(password string) bool {
return len(password) >= 60 && (password[:4] == "$2a$" || password[:4] == "$2b$" || password[:4] == "$2y$")
}
// containsBinaryData checks if a string contains binary data
func containsBinaryData(s string) bool {
for _, b := range []byte(s) {
if b < 32 && b != 9 && b != 10 && b != 13 { // Allow tab, newline, carriage return
if b < 32 && b != 9 && b != 10 && b != 13 {
return true
}
}
return false
}
// isDuplicateColumnError checks if an error is due to duplicate column
func isDuplicateColumnError(err error) bool {
errStr := err.Error()
return fmt.Sprintf("%v", errStr) == "duplicate column name: password_backup" ||
fmt.Sprintf("%v", errStr) == "SQLITE_ERROR: duplicate column name: password_backup"
}
// decryptOldPassword attempts to decrypt using the old encryption method
// This is a simplified version of the old DecryptPassword function
func decryptOldPassword(encryptedPassword string) (string, error) {
// This would use the old decryption logic
// For now, we'll return an error to force treating as plain text
// In a real scenario, you'd implement the old decryption here
return "", errors.New("old decryption not implemented - treating as plain text")
}
// Down reverses the migration (if needed)
func (m *Migration001UpgradePasswordSecurity) Down() error {
logging.Error("Password security migration rollback is not supported for security reasons")
return errors.New("password security migration rollback is not supported")
}
// RunMigration is a convenience function to run the migration
func RunPasswordSecurityMigration(db *gorm.DB) error {
migration := NewMigration001UpgradePasswordSecurity(db)
return migration.Up()

View File

@@ -9,21 +9,17 @@ import (
"gorm.io/gorm"
)
// Migration002MigrateToUUID migrates tables from integer IDs to UUIDs
type Migration002MigrateToUUID struct {
DB *gorm.DB
}
// NewMigration002MigrateToUUID creates a new UUID migration
func NewMigration002MigrateToUUID(db *gorm.DB) *Migration002MigrateToUUID {
return &Migration002MigrateToUUID{DB: db}
}
// Up executes the migration
func (m *Migration002MigrateToUUID) Up() error {
logging.Info("Checking UUID migration...")
// Check if migration is needed by looking at the servers table structure
if !m.needsMigration() {
logging.Info("UUID migration not needed - tables already use UUID primary keys")
return nil
@@ -31,7 +27,6 @@ func (m *Migration002MigrateToUUID) Up() error {
logging.Info("Starting UUID migration...")
// Check if migration has already been applied
var migrationRecord MigrationRecord
err := m.DB.Where("migration_name = ?", "002_migrate_to_uuid").First(&migrationRecord).Error
if err == nil {
@@ -39,12 +34,10 @@ func (m *Migration002MigrateToUUID) Up() error {
return nil
}
// Create migration tracking table if it doesn't exist
if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil {
return fmt.Errorf("failed to create migration tracking table: %v", err)
}
// Execute the UUID migration using the existing migration function
logging.Info("Executing UUID migration...")
if err := runUUIDMigrationSQL(m.DB); err != nil {
return fmt.Errorf("failed to execute UUID migration: %v", err)
@@ -54,9 +47,7 @@ func (m *Migration002MigrateToUUID) Up() error {
return nil
}
// needsMigration checks if the UUID migration is needed by examining table structure
func (m *Migration002MigrateToUUID) needsMigration() bool {
// Check if servers table exists and has integer primary key
var result struct {
Type string `gorm:"column:type"`
}
@@ -67,29 +58,22 @@ func (m *Migration002MigrateToUUID) needsMigration() bool {
`).Scan(&result).Error
if err != nil || result.Type == "" {
// Table doesn't exist or no primary key found - assume no migration needed
return false
}
// If the primary key is INTEGER, we need migration
// If it's TEXT (UUID), migration already done
return result.Type == "INTEGER" || result.Type == "integer"
}
// Down reverses the migration (not implemented for safety)
func (m *Migration002MigrateToUUID) Down() error {
logging.Error("UUID migration rollback is not supported for data safety reasons")
return fmt.Errorf("UUID migration rollback is not supported")
}
// runUUIDMigrationSQL executes the UUID migration using the SQL file
func runUUIDMigrationSQL(db *gorm.DB) error {
// Disable foreign key constraints during migration
if err := db.Exec("PRAGMA foreign_keys=OFF").Error; err != nil {
return fmt.Errorf("failed to disable foreign keys: %v", err)
}
// Start transaction
tx := db.Begin()
if tx.Error != nil {
return fmt.Errorf("failed to start transaction: %v", tx.Error)
@@ -101,25 +85,21 @@ func runUUIDMigrationSQL(db *gorm.DB) error {
}
}()
// Read the migration SQL from file
sqlPath := filepath.Join("scripts", "migrations", "002_migrate_servers_to_uuid.sql")
migrationSQL, err := ioutil.ReadFile(sqlPath)
if err != nil {
return fmt.Errorf("failed to read migration SQL file: %v", err)
}
// Execute the migration
if err := tx.Exec(string(migrationSQL)).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to execute migration: %v", err)
}
// Commit transaction
if err := tx.Commit().Error; err != nil {
return fmt.Errorf("failed to commit migration: %v", err)
}
// Re-enable foreign key constraints
if err := db.Exec("PRAGMA foreign_keys=ON").Error; err != nil {
return fmt.Errorf("failed to re-enable foreign keys: %v", err)
}
@@ -127,7 +107,6 @@ func runUUIDMigrationSQL(db *gorm.DB) error {
return nil
}
// RunUUIDMigration is a convenience function to run the migration
func RunUUIDMigration(db *gorm.DB) error {
migration := NewMigration002MigrateToUUID(db)
return migration.Up()

View File

@@ -0,0 +1,106 @@
package migrations
import (
"acc-server-manager/local/utl/logging"
"fmt"
"gorm.io/gorm"
)
type UpdateStateHistorySessions struct {
DB *gorm.DB
}
func NewUpdateStateHistorySessions(db *gorm.DB) *UpdateStateHistorySessions {
return &UpdateStateHistorySessions{DB: db}
}
func (m *UpdateStateHistorySessions) Up() error {
logging.Info("Checking UUID migration...")
if !m.needsMigration() {
logging.Info("UUID migration not needed - tables already use UUID primary keys")
return nil
}
logging.Info("Starting UUID migration...")
var migrationRecord MigrationRecord
err := m.DB.Where("migration_name = ?", "002_migrate_to_uuid").First(&migrationRecord).Error
if err == nil {
logging.Info("UUID migration already applied, skipping")
return nil
}
if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil {
return fmt.Errorf("failed to create migration tracking table: %v", err)
}
logging.Info("Executing UUID migration...")
if err := runUUIDMigrationSQL(m.DB); err != nil {
return fmt.Errorf("failed to execute UUID migration: %v", err)
}
logging.Info("UUID migration completed successfully")
return nil
}
func (m *UpdateStateHistorySessions) needsMigration() bool {
var result struct {
Exists bool `gorm:"column:exists"`
}
err := m.DB.Raw(`
SELECT count(*) > 0 as exists FROM state_history
WHERE length(session) > 1 LIMIT 1;
`).Scan(&result).Error
if err != nil || !result.Exists {
return false
}
return result.Exists
}
func (m *UpdateStateHistorySessions) Down() error {
logging.Error("UUID migration rollback is not supported for data safety reasons")
return fmt.Errorf("UUID migration rollback is not supported")
}
func runUpdateStateHistorySessionsMigration(db *gorm.DB) error {
if err := db.Exec("PRAGMA foreign_keys=OFF").Error; err != nil {
return fmt.Errorf("failed to disable foreign keys: %v", err)
}
tx := db.Begin()
if tx.Error != nil {
return fmt.Errorf("failed to start transaction: %v", tx.Error)
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
migrationSQL := "UPDATE state_history SET session = upper(substr(session, 1, 1));"
if err := tx.Exec(string(migrationSQL)).Error; err != nil {
tx.Rollback()
return fmt.Errorf("failed to execute migration: %v", err)
}
if err := tx.Commit().Error; err != nil {
return fmt.Errorf("failed to commit migration: %v", err)
}
if err := db.Exec("PRAGMA foreign_keys=ON").Error; err != nil {
return fmt.Errorf("failed to re-enable foreign keys: %v", err)
}
return nil
}
func RunUpdateStateHistorySessionsMigration(db *gorm.DB) error {
migration := NewUpdateStateHistorySessions(db)
return migration.Up()
}

View File

@@ -6,20 +6,17 @@ import (
"time"
)
// StatusCache represents a cached server status with expiration
type StatusCache struct {
Status ServiceStatus
UpdatedAt time.Time
}
// CacheConfig holds configuration for cache behavior
type CacheConfig struct {
ExpirationTime time.Duration // How long before a cache entry expires
ThrottleTime time.Duration // Minimum time between status checks
DefaultStatus ServiceStatus // Default status to return when throttled
ExpirationTime time.Duration
ThrottleTime time.Duration
DefaultStatus ServiceStatus
}
// ServerStatusCache manages cached server statuses
type ServerStatusCache struct {
sync.RWMutex
cache map[string]*StatusCache
@@ -27,7 +24,6 @@ type ServerStatusCache struct {
lastChecked map[string]time.Time
}
// NewServerStatusCache creates a new server status cache
func NewServerStatusCache(config CacheConfig) *ServerStatusCache {
return &ServerStatusCache{
cache: make(map[string]*StatusCache),
@@ -36,12 +32,10 @@ func NewServerStatusCache(config CacheConfig) *ServerStatusCache {
}
}
// GetStatus retrieves the cached status or indicates if a fresh check is needed
func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool) {
c.RLock()
defer c.RUnlock()
// Check if we're being throttled
if lastCheck, exists := c.lastChecked[serviceName]; exists {
if time.Since(lastCheck) < c.config.ThrottleTime {
if cached, ok := c.cache[serviceName]; ok {
@@ -51,7 +45,6 @@ func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool)
}
}
// Check if we have a valid cached entry
if cached, ok := c.cache[serviceName]; ok {
if time.Since(cached.UpdatedAt) < c.config.ExpirationTime {
return cached.Status, false
@@ -61,7 +54,6 @@ func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool)
return StatusUnknown, true
}
// UpdateStatus updates the cache with a new status
func (c *ServerStatusCache) UpdateStatus(serviceName string, status ServiceStatus) {
c.Lock()
defer c.Unlock()
@@ -73,7 +65,6 @@ func (c *ServerStatusCache) UpdateStatus(serviceName string, status ServiceStatu
c.lastChecked[serviceName] = time.Now()
}
// InvalidateStatus removes a specific service from the cache
func (c *ServerStatusCache) InvalidateStatus(serviceName string) {
c.Lock()
defer c.Unlock()
@@ -82,7 +73,6 @@ func (c *ServerStatusCache) InvalidateStatus(serviceName string) {
delete(c.lastChecked, serviceName)
}
// Clear removes all entries from the cache
func (c *ServerStatusCache) Clear() {
c.Lock()
defer c.Unlock()
@@ -91,13 +81,11 @@ func (c *ServerStatusCache) Clear() {
c.lastChecked = make(map[string]time.Time)
}
// LookupCache provides a generic cache for lookup data
type LookupCache struct {
sync.RWMutex
data map[string]interface{}
}
// NewLookupCache creates a new lookup cache
func NewLookupCache() *LookupCache {
logging.Debug("Initializing new LookupCache")
return &LookupCache{
@@ -105,7 +93,6 @@ func NewLookupCache() *LookupCache {
}
}
// Get retrieves a cached value by key
func (c *LookupCache) Get(key string) (interface{}, bool) {
c.RLock()
defer c.RUnlock()
@@ -119,7 +106,6 @@ func (c *LookupCache) Get(key string) (interface{}, bool) {
return value, exists
}
// Set stores a value in the cache
func (c *LookupCache) Set(key string, value interface{}) {
c.Lock()
defer c.Unlock()
@@ -128,7 +114,6 @@ func (c *LookupCache) Set(key string, value interface{}) {
logging.Debug("Cache SET for key: %s", key)
}
// Clear removes all entries from the cache
func (c *LookupCache) Clear() {
c.Lock()
defer c.Unlock()
@@ -137,13 +122,11 @@ func (c *LookupCache) Clear() {
logging.Debug("Cache CLEARED")
}
// ConfigEntry represents a cached configuration entry with its update time
type ConfigEntry[T any] struct {
Data T
UpdatedAt time.Time
}
// getConfigFromCache is a generic helper function to retrieve cached configs
func getConfigFromCache[T any](cache map[string]*ConfigEntry[T], serverID string, expirationTime time.Duration) (*T, bool) {
if entry, ok := cache[serverID]; ok {
if time.Since(entry.UpdatedAt) < expirationTime {
@@ -157,7 +140,6 @@ func getConfigFromCache[T any](cache map[string]*ConfigEntry[T], serverID string
return nil, false
}
// updateConfigInCache is a generic helper function to update cached configs
func updateConfigInCache[T any](cache map[string]*ConfigEntry[T], serverID string, data T) {
cache[serverID] = &ConfigEntry[T]{
Data: data,
@@ -166,7 +148,6 @@ func updateConfigInCache[T any](cache map[string]*ConfigEntry[T], serverID strin
logging.Debug("Config cache SET for server ID: %s", serverID)
}
// ServerConfigCache manages cached server configurations
type ServerConfigCache struct {
sync.RWMutex
configuration map[string]*ConfigEntry[Configuration]
@@ -177,7 +158,6 @@ type ServerConfigCache struct {
config CacheConfig
}
// NewServerConfigCache creates a new server configuration cache
func NewServerConfigCache(config CacheConfig) *ServerConfigCache {
logging.Debug("Initializing new ServerConfigCache with expiration time: %v, throttle time: %v", config.ExpirationTime, config.ThrottleTime)
return &ServerConfigCache{
@@ -190,7 +170,6 @@ func NewServerConfigCache(config CacheConfig) *ServerConfigCache {
}
}
// GetConfiguration retrieves a cached configuration
func (c *ServerConfigCache) GetConfiguration(serverID string) (*Configuration, bool) {
c.RLock()
defer c.RUnlock()
@@ -198,7 +177,6 @@ func (c *ServerConfigCache) GetConfiguration(serverID string) (*Configuration, b
return getConfigFromCache(c.configuration, serverID, c.config.ExpirationTime)
}
// GetAssistRules retrieves cached assist rules
func (c *ServerConfigCache) GetAssistRules(serverID string) (*AssistRules, bool) {
c.RLock()
defer c.RUnlock()
@@ -206,7 +184,6 @@ func (c *ServerConfigCache) GetAssistRules(serverID string) (*AssistRules, bool)
return getConfigFromCache(c.assistRules, serverID, c.config.ExpirationTime)
}
// GetEvent retrieves cached event configuration
func (c *ServerConfigCache) GetEvent(serverID string) (*EventConfig, bool) {
c.RLock()
defer c.RUnlock()
@@ -214,7 +191,6 @@ func (c *ServerConfigCache) GetEvent(serverID string) (*EventConfig, bool) {
return getConfigFromCache(c.event, serverID, c.config.ExpirationTime)
}
// GetEventRules retrieves cached event rules
func (c *ServerConfigCache) GetEventRules(serverID string) (*EventRules, bool) {
c.RLock()
defer c.RUnlock()
@@ -222,7 +198,6 @@ func (c *ServerConfigCache) GetEventRules(serverID string) (*EventRules, bool) {
return getConfigFromCache(c.eventRules, serverID, c.config.ExpirationTime)
}
// GetSettings retrieves cached server settings
func (c *ServerConfigCache) GetSettings(serverID string) (*ServerSettings, bool) {
c.RLock()
defer c.RUnlock()
@@ -230,7 +205,6 @@ func (c *ServerConfigCache) GetSettings(serverID string) (*ServerSettings, bool)
return getConfigFromCache(c.settings, serverID, c.config.ExpirationTime)
}
// UpdateConfiguration updates the configuration cache
func (c *ServerConfigCache) UpdateConfiguration(serverID string, config Configuration) {
c.Lock()
defer c.Unlock()
@@ -238,7 +212,6 @@ func (c *ServerConfigCache) UpdateConfiguration(serverID string, config Configur
updateConfigInCache(c.configuration, serverID, config)
}
// UpdateAssistRules updates the assist rules cache
func (c *ServerConfigCache) UpdateAssistRules(serverID string, rules AssistRules) {
c.Lock()
defer c.Unlock()
@@ -246,7 +219,6 @@ func (c *ServerConfigCache) UpdateAssistRules(serverID string, rules AssistRules
updateConfigInCache(c.assistRules, serverID, rules)
}
// UpdateEvent updates the event configuration cache
func (c *ServerConfigCache) UpdateEvent(serverID string, event EventConfig) {
c.Lock()
defer c.Unlock()
@@ -254,7 +226,6 @@ func (c *ServerConfigCache) UpdateEvent(serverID string, event EventConfig) {
updateConfigInCache(c.event, serverID, event)
}
// UpdateEventRules updates the event rules cache
func (c *ServerConfigCache) UpdateEventRules(serverID string, rules EventRules) {
c.Lock()
defer c.Unlock()
@@ -262,7 +233,6 @@ func (c *ServerConfigCache) UpdateEventRules(serverID string, rules EventRules)
updateConfigInCache(c.eventRules, serverID, rules)
}
// UpdateSettings updates the server settings cache
func (c *ServerConfigCache) UpdateSettings(serverID string, settings ServerSettings) {
c.Lock()
defer c.Unlock()
@@ -270,7 +240,6 @@ func (c *ServerConfigCache) UpdateSettings(serverID string, settings ServerSetti
updateConfigInCache(c.settings, serverID, settings)
}
// InvalidateServerCache removes all cached configurations for a specific server
func (c *ServerConfigCache) InvalidateServerCache(serverID string) {
c.Lock()
defer c.Unlock()
@@ -283,7 +252,6 @@ func (c *ServerConfigCache) InvalidateServerCache(serverID string) {
delete(c.settings, serverID)
}
// Clear removes all entries from the cache
func (c *ServerConfigCache) Clear() {
c.Lock()
defer c.Unlock()

View File

@@ -13,17 +13,15 @@ import (
type IntString int
type IntBool int
// Config tracks configuration modifications
type Config struct {
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
ServerID uuid.UUID `json:"serverId" gorm:"not null;type:uuid"`
ConfigFile string `json:"configFile" gorm:"not null"` // e.g. "settings.json"
ConfigFile string `json:"configFile" gorm:"not null"`
OldConfig string `json:"oldConfig" gorm:"type:text"`
NewConfig string `json:"newConfig" gorm:"type:text"`
ChangedAt time.Time `json:"changedAt" gorm:"default:CURRENT_TIMESTAMP"`
}
// BeforeCreate is a GORM hook that runs before creating new config entries
func (c *Config) BeforeCreate(tx *gorm.DB) error {
if c.ID == uuid.Nil {
c.ID = uuid.New()
@@ -79,11 +77,11 @@ type EventConfig struct {
}
type Session struct {
HourOfDay IntString `json:"hourOfDay"`
DayOfWeekend IntString `json:"dayOfWeekend"`
TimeMultiplier IntString `json:"timeMultiplier"`
SessionType string `json:"sessionType"`
SessionDurationMinutes IntString `json:"sessionDurationMinutes"`
HourOfDay IntString `json:"hourOfDay"`
DayOfWeekend IntString `json:"dayOfWeekend"`
TimeMultiplier IntString `json:"timeMultiplier"`
SessionType TrackSession `json:"sessionType"`
SessionDurationMinutes IntString `json:"sessionDurationMinutes"`
}
type AssistRules struct {
@@ -121,8 +119,6 @@ type Configuration struct {
ConfigVersion IntString `json:"configVersion"`
}
// Known configuration keys
func (i *IntBool) UnmarshalJSON(b []byte) error {
var str int
if err := json.Unmarshal(b, &str); err == nil && str <= 1 {

View File

@@ -7,7 +7,6 @@ import (
"gorm.io/gorm"
)
// BaseFilter contains common filter fields that can be embedded in other filters
type BaseFilter struct {
Page int `query:"page"`
PageSize int `query:"page_size"`
@@ -15,18 +14,15 @@ type BaseFilter struct {
SortDesc bool `query:"sort_desc"`
}
// DateRangeFilter adds date range filtering capabilities
type DateRangeFilter struct {
StartDate time.Time `query:"start_date" time_format:"2006-01-02T15:04:05Z07:00"`
EndDate time.Time `query:"end_date" time_format:"2006-01-02T15:04:05Z07:00"`
}
// ServerBasedFilter adds server ID filtering capability
type ServerBasedFilter struct {
ServerID string `param:"id"`
}
// ConfigFilter defines filtering options for Config queries
type ConfigFilter struct {
BaseFilter
ServerBasedFilter
@@ -34,13 +30,11 @@ type ConfigFilter struct {
ChangedAt time.Time `query:"changed_at" time_format:"2006-01-02T15:04:05Z07:00"`
}
// ApiFilter defines filtering options for Api queries
type ServiceControlFilter struct {
BaseFilter
ServiceControl string `query:"serviceControl"`
}
// MembershipFilter defines filtering options for User queries
type MembershipFilter struct {
BaseFilter
Username string `query:"username"`
@@ -48,36 +42,32 @@ type MembershipFilter struct {
RoleID string `query:"role_id"`
}
// Pagination returns the offset and limit for database queries
func (f *BaseFilter) Pagination() (offset, limit int) {
if f.Page < 1 {
f.Page = 1
}
if f.PageSize < 1 {
f.PageSize = 10 // Default page size
f.PageSize = 10
}
offset = (f.Page - 1) * f.PageSize
limit = f.PageSize
return
}
// GetSorting returns the sort field and direction for database queries
func (f *BaseFilter) GetSorting() (field string, desc bool) {
if f.SortBy == "" {
return "id", false // Default sorting
return "id", false
}
return f.SortBy, f.SortDesc
}
// IsDateRangeValid checks if both dates are set and start date is before end date
func (f *DateRangeFilter) IsDateRangeValid() bool {
if f.StartDate.IsZero() || f.EndDate.IsZero() {
return true // If either date is not set, consider it valid
return true
}
return f.StartDate.Before(f.EndDate)
}
// ApplyFilter applies the membership filter to a GORM query
func (f *MembershipFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
if f.Username != "" {
query = query.Where("username LIKE ?", "%"+f.Username+"%")
@@ -93,12 +83,10 @@ func (f *MembershipFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
return query
}
// Pagination returns the offset and limit for database queries
func (f *MembershipFilter) Pagination() (offset, limit int) {
return f.BaseFilter.Pagination()
}
// GetSorting returns the sort field and direction for database queries
func (f *MembershipFilter) GetSorting() (field string, desc bool) {
return f.BaseFilter.GetSorting()
}

View File

@@ -1,31 +1,26 @@
package model
// Track represents a track and its capacity
type Track struct {
Name string `json:"track" gorm:"primaryKey;size:50"`
UniquePitBoxes int `json:"unique_pit_boxes"`
PrivateServerSlots int `json:"private_server_slots"`
}
// CarModel represents a car model mapping
type CarModel struct {
Value int `json:"value" gorm:"primaryKey"`
CarModel string `json:"car_model"`
}
// DriverCategory represents driver skill categories
type DriverCategory struct {
Value int `json:"value" gorm:"primaryKey"`
Category string `json:"category"`
}
// CupCategory represents championship cup categories
type CupCategory struct {
Value int `json:"value" gorm:"primaryKey"`
Category string `json:"category"`
}
// SessionType represents session types
type SessionType struct {
Value int `json:"value" gorm:"primaryKey"`
SessionType string `json:"session_type"`

View File

@@ -32,8 +32,6 @@ type BaseModel struct {
DateUpdated time.Time `json:"dateUpdated"`
}
// Init
// Initializes base model with DateCreated, DateUpdated, and Id values.
func (cm *BaseModel) Init() {
date := time.Now()
cm.Id = uuid.NewString()

View File

@@ -5,15 +5,13 @@ import (
"gorm.io/gorm"
)
// Permission represents an action that can be performed in the system.
type Permission struct {
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
Name string `json:"name" gorm:"unique_index;not null"`
}
// BeforeCreate is a GORM hook that runs before creating new credentials
func (s *Permission) BeforeCreate(tx *gorm.DB) error {
s.ID = uuid.New()
return nil
}
}

View File

@@ -1,6 +1,5 @@
package model
// Permission constants
const (
ServerView = "server.view"
ServerCreate = "server.create"
@@ -27,7 +26,6 @@ const (
MembershipEdit = "membership.edit"
)
// AllPermissions returns a slice of all permission strings.
func AllPermissions() []string {
return []string{
ServerView,

View File

@@ -5,16 +5,14 @@ import (
"gorm.io/gorm"
)
// Role represents a user role in the system.
type Role struct {
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
Name string `json:"name" gorm:"unique_index;not null"`
Permissions []Permission `json:"permissions" gorm:"many2many:role_permissions;"`
}
// BeforeCreate is a GORM hook that runs before creating new credentials
func (s *Role) BeforeCreate(tx *gorm.DB) error {
s.ID = uuid.New()
return nil
}
}

View File

@@ -16,7 +16,6 @@ const (
ServiceNamePrefix = "ACC-Server"
)
// Server represents an ACC server instance
type ServerAPI struct {
Name string `json:"name"`
Status ServiceStatus `json:"status"`
@@ -35,28 +34,27 @@ func (s *Server) ToServerAPI() *ServerAPI {
}
}
// Server represents an ACC server instance
type Server struct {
ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"`
Name string `gorm:"not null" json:"name"`
Status ServiceStatus `json:"status" gorm:"-"`
IP string `gorm:"not null" json:"-"`
Port int `gorm:"not null" json:"-"`
Path string `gorm:"not null" json:"path"` // e.g. "/acc/servers/server1/"
ServiceName string `gorm:"not null" json:"serviceName"` // Windows service name
Path string `gorm:"not null" json:"path"`
ServiceName string `gorm:"not null" json:"serviceName"`
State *ServerState `gorm:"-" json:"state"`
DateCreated time.Time `json:"dateCreated"`
FromSteamCMD bool `gorm:"not null; default:true" json:"-"`
}
type PlayerState struct {
CarID int // Car ID in broadcast packets
DriverName string // Optional: pulled from registration packet
CarID int
DriverName string
TeamName string
CarModel string
CurrentLap int
LastLapTime int // in milliseconds
BestLapTime int // in milliseconds
LastLapTime int
BestLapTime int
Position int
ConnectedAt time.Time
DisconnectedAt *time.Time
@@ -67,23 +65,18 @@ type State struct {
Session string `json:"session"`
SessionStart time.Time `json:"sessionStart"`
PlayerCount int `json:"playerCount"`
// Players map[int]*PlayerState
// etc.
}
type ServerState struct {
sync.RWMutex `swaggerignore:"-" json:"-"`
Session string `json:"session"`
SessionStart time.Time `json:"sessionStart"`
PlayerCount int `json:"playerCount"`
Track string `json:"track"`
MaxConnections int `json:"maxConnections"`
SessionDurationMinutes int `json:"sessionDurationMinutes"`
// Players map[int]*PlayerState
// etc.
Session TrackSession `json:"session"`
SessionStart time.Time `json:"sessionStart"`
PlayerCount int `json:"playerCount"`
Track string `json:"track"`
MaxConnections int `json:"maxConnections"`
SessionDurationMinutes int `json:"sessionDurationMinutes"`
}
// ServerFilter defines filtering options for Server queries
type ServerFilter struct {
BaseFilter
ServerBasedFilter
@@ -92,9 +85,7 @@ type ServerFilter struct {
Status string `query:"status"`
}
// ApplyFilter implements the Filterable interface
func (f *ServerFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
// Apply server filter
if f.ServerID != "" {
if serverUUID, err := uuid.Parse(f.ServerID); err == nil {
query = query.Where("id = ?", serverUUID)
@@ -104,18 +95,19 @@ func (f *ServerFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
return query
}
// BeforeCreate is a GORM hook that runs before creating a new server
func (s *Server) GenerateUUID() {
if s.ID == uuid.Nil {
s.ID = uuid.New()
}
}
func (s *Server) BeforeCreate(tx *gorm.DB) error {
if s.Name == "" {
return errors.New("server name is required")
}
// Generate UUID if not set
if s.ID == uuid.Nil {
s.ID = uuid.New()
}
s.GenerateUUID()
// Generate service name and config path if not set
if s.ServiceName == "" {
s.ServiceName = s.GenerateServiceName()
}
@@ -123,7 +115,6 @@ func (s *Server) BeforeCreate(tx *gorm.DB) error {
s.Path = s.GenerateServerPath(BaseServerPath)
}
// Set creation date if not set
if s.DateCreated.IsZero() {
s.DateCreated = time.Now().UTC()
}
@@ -131,19 +122,14 @@ func (s *Server) BeforeCreate(tx *gorm.DB) error {
return nil
}
// GenerateServiceName creates a unique service name based on the server name
func (s *Server) GenerateServiceName() string {
// If ID is set, use it
if s.ID != uuid.Nil {
return fmt.Sprintf("%s-%s", ServiceNamePrefix, s.ID.String()[:8])
}
// Otherwise use a timestamp-based unique identifier
return fmt.Sprintf("%s-%d", ServiceNamePrefix, time.Now().UnixNano())
}
// GenerateServerPath creates the config path based on the service name
func (s *Server) GenerateServerPath(steamCMDPath string) string {
// Ensure service name is set
if s.ServiceName == "" {
s.ServiceName = s.GenerateServiceName()
}

View File

@@ -17,7 +17,6 @@ const (
StatusRunning
)
// String converts the ServiceStatus to its string representation
func (s ServiceStatus) String() string {
switch s {
case StatusRunning:
@@ -35,7 +34,6 @@ func (s ServiceStatus) String() string {
}
}
// ParseServiceStatus converts a string to ServiceStatus
func ParseServiceStatus(s string) ServiceStatus {
switch s {
case "SERVICE_RUNNING":
@@ -53,31 +51,24 @@ func ParseServiceStatus(s string) ServiceStatus {
}
}
// MarshalJSON implements json.Marshaler interface
func (s ServiceStatus) MarshalJSON() ([]byte, error) {
// Return the numeric value instead of string
return []byte(strconv.Itoa(int(s))), nil
}
// UnmarshalJSON implements json.Unmarshaler interface
func (s *ServiceStatus) UnmarshalJSON(data []byte) error {
// Try to parse as number first
if i, err := strconv.Atoi(string(data)); err == nil {
*s = ServiceStatus(i)
return nil
}
// Fallback to string parsing for backward compatibility
str := string(data)
if len(str) >= 2 {
// Remove quotes if present
str = str[1 : len(str)-1]
}
*s = ParseServiceStatus(str)
return nil
}
// Scan implements the sql.Scanner interface
func (s *ServiceStatus) Scan(value interface{}) error {
if value == nil {
*s = StatusUnknown
@@ -99,7 +90,6 @@ func (s *ServiceStatus) Scan(value interface{}) error {
}
}
// Value implements the driver.Valuer interface
func (s ServiceStatus) Value() (driver.Value, error) {
return s.String(), nil
}

View File

@@ -1,33 +1,31 @@
package model
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// StateHistoryFilter combines common filter capabilities
type StateHistoryFilter struct {
ServerBasedFilter // Adds server ID from path parameter
DateRangeFilter // Adds date range filtering
ServerBasedFilter
DateRangeFilter
// Additional fields specific to state history
Session string `query:"session"`
MinPlayers *int `query:"min_players"`
MaxPlayers *int `query:"max_players"`
Session TrackSession `query:"session"`
MinPlayers *int `query:"min_players"`
MaxPlayers *int `query:"max_players"`
}
// ApplyFilter implements the Filterable interface
func (f *StateHistoryFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
// Apply server filter
if f.ServerID != "" {
if serverUUID, err := uuid.Parse(f.ServerID); err == nil {
query = query.Where("server_id = ?", serverUUID)
}
}
// Apply date range filter if set
timeZero := time.Time{}
if f.StartDate != timeZero {
query = query.Where("date_created >= ?", f.StartDate)
@@ -36,12 +34,10 @@ func (f *StateHistoryFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
query = query.Where("date_created <= ?", f.EndDate)
}
// Apply session filter if set
if f.Session != "" {
query = query.Where("session = ?", f.Session)
}
// Apply player count filters if set
if f.MinPlayers != nil {
query = query.Where("player_count >= ?", *f.MinPlayers)
}
@@ -52,19 +48,68 @@ func (f *StateHistoryFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
return query
}
type StateHistory struct {
ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"`
ServerID uuid.UUID `json:"serverId" gorm:"not null;type:uuid"`
Session string `json:"session"`
Track string `json:"track"`
PlayerCount int `json:"playerCount"`
DateCreated time.Time `json:"dateCreated"`
SessionStart time.Time `json:"sessionStart"`
SessionDurationMinutes int `json:"sessionDurationMinutes"`
SessionID uuid.UUID `json:"sessionId" gorm:"not null;type:uuid"` // Unique identifier for each session/event
type TrackSession string
const (
SessionPractice TrackSession = "P"
SessionQualify TrackSession = "Q"
SessionRace TrackSession = "R"
SessionUnknown TrackSession = "U"
)
func (i *TrackSession) UnmarshalJSON(b []byte) error {
var str string
if err := json.Unmarshal(b, &str); err == nil {
*i = ToTrackSession(str)
return nil
}
return fmt.Errorf("invalid TrackSession value")
}
func (i TrackSession) Humanize() string {
switch i {
case SessionPractice:
return "Practice"
case SessionQualify:
return "Qualifying"
case SessionRace:
return "Race"
default:
return "Unknown"
}
}
func ToTrackSession(i string) TrackSession {
sessionAbrv := strings.ToUpper(i[:1])
switch sessionAbrv {
case "P":
return SessionPractice
case "Q":
return SessionQualify
case "R":
return SessionRace
default:
return SessionUnknown
}
}
func (i TrackSession) ToString() string {
return string(i)
}
type StateHistory struct {
ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"`
ServerID uuid.UUID `json:"serverId" gorm:"not null;type:uuid"`
Session TrackSession `json:"session"`
Track string `json:"track"`
PlayerCount int `json:"playerCount"`
DateCreated time.Time `json:"dateCreated"`
SessionStart time.Time `json:"sessionStart"`
SessionDurationMinutes int `json:"sessionDurationMinutes"`
SessionID uuid.UUID `json:"sessionId" gorm:"not null;type:uuid"`
}
// BeforeCreate is a GORM hook that runs before creating new state history entries
func (sh *StateHistory) BeforeCreate(tx *gorm.DB) error {
if sh.ID == uuid.Nil {
sh.ID = uuid.New()

View File

@@ -1,36 +1,38 @@
package model
import "github.com/google/uuid"
type SessionCount struct {
Name string `json:"name"`
Count int `json:"count"`
Name TrackSession `json:"name"`
Count int `json:"count"`
}
type DailyActivity struct {
Date string `json:"date"`
SessionsCount int `json:"sessionsCount"`
SessionsCount int `json:"sessionsCount"`
}
type PlayerCountPoint struct {
Timestamp string `json:"timestamp"`
Count float64 `json:"count"`
Timestamp string `json:"timestamp"`
Count float64 `json:"count"`
}
type StateHistoryStats struct {
AveragePlayers float64 `json:"averagePlayers"`
PeakPlayers int `json:"peakPlayers"`
TotalSessions int `json:"totalSessions"`
TotalPlaytime int `json:"totalPlaytime" gorm:"-"` // in minutes
AveragePlayers float64 `json:"averagePlayers"`
PeakPlayers int `json:"peakPlayers"`
TotalSessions int `json:"totalSessions"`
TotalPlaytime int `json:"totalPlaytime" gorm:"-"`
PlayerCountOverTime []PlayerCountPoint `json:"playerCountOverTime" gorm:"-"`
SessionTypes []SessionCount `json:"sessionTypes" gorm:"-"`
DailyActivity []DailyActivity `json:"dailyActivity" gorm:"-"`
RecentSessions []RecentSession `json:"recentSessions" gorm:"-"`
}
}
type RecentSession struct {
ID uint `json:"id"`
Date string `json:"date"`
Type string `json:"type"`
Track string `json:"track"`
Duration int `json:"duration"`
Players int `json:"players"`
}
ID uuid.UUID `json:"id"`
Date string `json:"date"`
Type TrackSession `json:"type"`
Track string `json:"track"`
Duration int `json:"duration"`
Players int `json:"players"`
}

168
local/model/steam_2fa.go Normal file
View File

@@ -0,0 +1,168 @@
package model
import (
"fmt"
"sync"
"time"
"github.com/google/uuid"
)
type Steam2FAStatus string
const (
Steam2FAStatusIdle Steam2FAStatus = "idle"
Steam2FAStatusPending Steam2FAStatus = "pending"
Steam2FAStatusComplete Steam2FAStatus = "complete"
Steam2FAStatusError Steam2FAStatus = "error"
)
type Steam2FARequest struct {
ID string `json:"id"`
Status Steam2FAStatus `json:"status"`
Message string `json:"message"`
RequestTime time.Time `json:"requestTime"`
CompletedAt *time.Time `json:"completedAt,omitempty"`
ErrorMsg string `json:"errorMsg,omitempty"`
ServerID *uuid.UUID `json:"serverId,omitempty"`
}
// Steam2FAManager manages 2FA requests and responses
type Steam2FAManager struct {
mu sync.RWMutex
requests map[string]*Steam2FARequest
channels map[string]chan bool
}
func NewSteam2FAManager() *Steam2FAManager {
return &Steam2FAManager{
requests: make(map[string]*Steam2FARequest),
channels: make(map[string]chan bool),
}
}
func (m *Steam2FAManager) CreateRequest(message string, serverID *uuid.UUID) *Steam2FARequest {
m.mu.Lock()
defer m.mu.Unlock()
id := uuid.New().String()
request := &Steam2FARequest{
ID: id,
Status: Steam2FAStatusPending,
Message: message,
RequestTime: time.Now(),
ServerID: serverID,
}
m.requests[id] = request
m.channels[id] = make(chan bool, 1)
return request
}
func (m *Steam2FAManager) GetRequest(id string) (*Steam2FARequest, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
req, exists := m.requests[id]
return req, exists
}
func (m *Steam2FAManager) GetPendingRequests() []*Steam2FARequest {
m.mu.RLock()
defer m.mu.RUnlock()
var pending []*Steam2FARequest
for _, req := range m.requests {
if req.Status == Steam2FAStatusPending {
pending = append(pending, req)
}
}
return pending
}
func (m *Steam2FAManager) CompleteRequest(id string) error {
m.mu.Lock()
defer m.mu.Unlock()
req, exists := m.requests[id]
if !exists {
return fmt.Errorf("request %s not found", id)
}
if req.Status != Steam2FAStatusPending {
return fmt.Errorf("request %s is not pending", id)
}
now := time.Now()
req.Status = Steam2FAStatusComplete
req.CompletedAt = &now
// Signal the waiting goroutine
if ch, exists := m.channels[id]; exists {
select {
case ch <- true:
default:
}
}
return nil
}
func (m *Steam2FAManager) ErrorRequest(id string, errorMsg string) error {
m.mu.Lock()
defer m.mu.Unlock()
req, exists := m.requests[id]
if !exists {
return fmt.Errorf("request %s not found", id)
}
req.Status = Steam2FAStatusError
req.ErrorMsg = errorMsg
// Signal the waiting goroutine with error
if ch, exists := m.channels[id]; exists {
select {
case ch <- false:
default:
}
}
return nil
}
func (m *Steam2FAManager) WaitForCompletion(id string, timeout time.Duration) (bool, error) {
m.mu.RLock()
ch, exists := m.channels[id]
m.mu.RUnlock()
if !exists {
return false, fmt.Errorf("request %s not found", id)
}
select {
case success := <-ch:
return success, nil
case <-time.After(timeout):
// Timeout - mark as error
m.ErrorRequest(id, "timeout waiting for 2FA confirmation")
return false, fmt.Errorf("timeout waiting for 2FA confirmation")
}
}
func (m *Steam2FAManager) CleanupOldRequests(maxAge time.Duration) {
m.mu.Lock()
defer m.mu.Unlock()
cutoff := time.Now().Add(-maxAge)
for id, req := range m.requests {
if req.RequestTime.Before(cutoff) {
delete(m.requests, id)
if ch, exists := m.channels[id]; exists {
close(ch)
delete(m.channels, id)
}
}
}
}

View File

@@ -16,21 +16,18 @@ import (
"gorm.io/gorm"
)
// SteamCredentials represents stored Steam login credentials
type SteamCredentials struct {
ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"`
Username string `gorm:"not null" json:"username"`
Password string `gorm:"not null" json:"-"` // Encrypted, not exposed in JSON
Password string `gorm:"not null" json:"-"`
DateCreated time.Time `json:"dateCreated"`
LastUpdated time.Time `json:"lastUpdated"`
}
// TableName specifies the table name for GORM
func (SteamCredentials) TableName() string {
return "steam_credentials"
}
// BeforeCreate is a GORM hook that runs before creating new credentials
func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error {
if s.ID == uuid.Nil {
s.ID = uuid.New()
@@ -42,7 +39,6 @@ func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error {
}
s.LastUpdated = now
// Encrypt password before saving
encrypted, err := EncryptPassword(s.Password)
if err != nil {
return err
@@ -52,11 +48,9 @@ func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error {
return nil
}
// BeforeUpdate is a GORM hook that runs before updating credentials
func (s *SteamCredentials) BeforeUpdate(tx *gorm.DB) error {
s.LastUpdated = time.Now().UTC()
// Only encrypt if password field is being updated
if tx.Statement.Changed("Password") {
encrypted, err := EncryptPassword(s.Password)
if err != nil {
@@ -68,9 +62,7 @@ func (s *SteamCredentials) BeforeUpdate(tx *gorm.DB) error {
return nil
}
// AfterFind is a GORM hook that runs after fetching credentials
func (s *SteamCredentials) AfterFind(tx *gorm.DB) error {
// Decrypt password after fetching
if s.Password != "" {
decrypted, err := DecryptPassword(s.Password)
if err != nil {
@@ -81,18 +73,15 @@ func (s *SteamCredentials) AfterFind(tx *gorm.DB) error {
return nil
}
// Validate checks if the credentials are valid with enhanced security checks
func (s *SteamCredentials) Validate() error {
if s.Username == "" {
return errors.New("username is required")
}
// Enhanced username validation
if len(s.Username) < 3 || len(s.Username) > 64 {
return errors.New("username must be between 3 and 64 characters")
}
// Check for valid characters in username (alphanumeric, underscore, hyphen)
if matched, _ := regexp.MatchString(`^[a-zA-Z0-9_-]+$`, s.Username); !matched {
return errors.New("username contains invalid characters")
}
@@ -101,7 +90,6 @@ func (s *SteamCredentials) Validate() error {
return errors.New("password is required")
}
// Basic password validation
if len(s.Password) < 6 {
return errors.New("password must be at least 6 characters long")
}
@@ -110,7 +98,6 @@ func (s *SteamCredentials) Validate() error {
return errors.New("password is too long")
}
// Check for obvious weak passwords
weakPasswords := []string{"password", "123456", "steam", "admin", "user"}
lowerPass := strings.ToLower(s.Password)
for _, weak := range weakPasswords {
@@ -122,8 +109,6 @@ func (s *SteamCredentials) Validate() error {
return nil
}
// GetEncryptionKey returns the encryption key from config.
// The key is loaded from the ENCRYPTION_KEY environment variable.
func GetEncryptionKey() []byte {
key := []byte(configs.EncryptionKey)
if len(key) != 32 {
@@ -132,7 +117,6 @@ func GetEncryptionKey() []byte {
return key
}
// EncryptPassword encrypts a password using AES-256-GCM with enhanced security
func EncryptPassword(password string) (string, error) {
if password == "" {
return "", errors.New("password cannot be empty")
@@ -148,33 +132,27 @@ func EncryptPassword(password string) (string, error) {
return "", err
}
// Create a new GCM cipher
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
// Create a cryptographically secure nonce
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", err
}
// Encrypt the password with authenticated encryption
ciphertext := gcm.Seal(nonce, nonce, []byte(password), nil)
// Return base64 encoded encrypted password
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// DecryptPassword decrypts an encrypted password with enhanced validation
func DecryptPassword(encryptedPassword string) (string, error) {
if encryptedPassword == "" {
return "", errors.New("encrypted password cannot be empty")
}
// Validate base64 format
if len(encryptedPassword) < 24 { // Minimum reasonable length
if len(encryptedPassword) < 24 {
return "", errors.New("invalid encrypted password format")
}
@@ -184,13 +162,11 @@ func DecryptPassword(encryptedPassword string) (string, error) {
return "", err
}
// Create a new GCM cipher
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
// Decode base64 encoded password
ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword)
if err != nil {
return "", errors.New("invalid base64 encoding")
@@ -207,7 +183,6 @@ func DecryptPassword(encryptedPassword string) (string, error) {
return "", errors.New("decryption failed - invalid ciphertext or key")
}
// Validate decrypted content
decrypted := string(plaintext)
if len(decrypted) == 0 || len(decrypted) > 1024 {
return "", errors.New("invalid decrypted password")

View File

@@ -8,7 +8,6 @@ import (
"gorm.io/gorm"
)
// User represents a user account in the system.
type User struct {
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
Username string `json:"username" gorm:"unique_index;not null"`
@@ -17,16 +16,13 @@ type User struct {
Role Role `json:"role"`
}
// BeforeCreate is a GORM hook that runs before creating new users
func (s *User) BeforeCreate(tx *gorm.DB) error {
s.ID = uuid.New()
// Validate password strength
if err := password.ValidatePasswordStrength(s.Password); err != nil {
return err
}
// Hash password before saving
hashed, err := password.HashPassword(s.Password)
if err != nil {
return err
@@ -36,11 +32,8 @@ func (s *User) BeforeCreate(tx *gorm.DB) error {
return nil
}
// BeforeUpdate is a GORM hook that runs before updating users
func (s *User) BeforeUpdate(tx *gorm.DB) error {
// Only hash if password field is being updated
if tx.Statement.Changed("Password") {
// Validate password strength
if err := password.ValidatePasswordStrength(s.Password); err != nil {
return err
}
@@ -55,14 +48,10 @@ func (s *User) BeforeUpdate(tx *gorm.DB) error {
return nil
}
// AfterFind is a GORM hook that runs after fetching users
func (s *User) AfterFind(tx *gorm.DB) error {
// Password remains hashed - never decrypt
// This hook is kept for potential future use
return nil
}
// Validate checks if the user data is valid
func (s *User) Validate() error {
if s.Username == "" {
return errors.New("username is required")
@@ -73,7 +62,6 @@ func (s *User) Validate() error {
return nil
}
// VerifyPassword verifies a plain text password against the stored hash
func (s *User) VerifyPassword(plainPassword string) error {
return password.VerifyPassword(s.Password, plainPassword)
}

80
local/model/websocket.go Normal file
View File

@@ -0,0 +1,80 @@
package model
import (
"github.com/google/uuid"
)
type ServerCreationStep string
const (
StepValidation ServerCreationStep = "validation"
StepDirectoryCreation ServerCreationStep = "directory_creation"
StepSteamDownload ServerCreationStep = "steam_download"
StepConfigGeneration ServerCreationStep = "config_generation"
StepServiceCreation ServerCreationStep = "service_creation"
StepFirewallRules ServerCreationStep = "firewall_rules"
StepDatabaseSave ServerCreationStep = "database_save"
StepCompleted ServerCreationStep = "completed"
)
type StepStatus string
const (
StatusPending StepStatus = "pending"
StatusInProgress StepStatus = "in_progress"
StatusCompleted StepStatus = "completed"
StatusFailed StepStatus = "failed"
)
type WebSocketMessageType string
const (
MessageTypeStep WebSocketMessageType = "step"
MessageTypeSteamOutput WebSocketMessageType = "steam_output"
MessageTypeError WebSocketMessageType = "error"
MessageTypeComplete WebSocketMessageType = "complete"
)
type WebSocketMessage struct {
Type WebSocketMessageType `json:"type"`
ServerID *uuid.UUID `json:"server_id,omitempty"`
Timestamp int64 `json:"timestamp"`
Data interface{} `json:"data"`
}
type StepMessage struct {
Step ServerCreationStep `json:"step"`
Status StepStatus `json:"status"`
Message string `json:"message,omitempty"`
Error string `json:"error,omitempty"`
}
type SteamOutputMessage struct {
Output string `json:"output"`
IsError bool `json:"is_error"`
}
type ErrorMessage struct {
Error string `json:"error"`
Details string `json:"details,omitempty"`
}
type CompleteMessage struct {
ServerID uuid.UUID `json:"server_id"`
Success bool `json:"success"`
Message string `json:"message"`
}
func GetStepDescription(step ServerCreationStep) string {
descriptions := map[ServerCreationStep]string{
StepValidation: "Validating server configuration",
StepDirectoryCreation: "Creating server directories",
StepSteamDownload: "Downloading server files via Steam",
StepConfigGeneration: "Generating server configuration files",
StepServiceCreation: "Creating Windows service",
StepFirewallRules: "Configuring firewall rules",
StepDatabaseSave: "Saving server to database",
StepCompleted: "Server creation completed",
}
return descriptions[step]
}

View File

@@ -7,13 +7,11 @@ import (
"gorm.io/gorm"
)
// BaseRepository provides generic CRUD operations for any model
type BaseRepository[T any, F any] struct {
db *gorm.DB
modelType T
}
// NewBaseRepository creates a new base repository for the given model type
func NewBaseRepository[T any, F any](db *gorm.DB, model T) *BaseRepository[T, F] {
return &BaseRepository[T, F]{
db: db,
@@ -21,23 +19,19 @@ func NewBaseRepository[T any, F any](db *gorm.DB, model T) *BaseRepository[T, F]
}
}
// GetAll retrieves all records based on the filter
func (r *BaseRepository[T, F]) GetAll(ctx context.Context, filter *F) (*[]T, error) {
result := new([]T)
query := r.db.WithContext(ctx).Model(&r.modelType)
// Apply filter conditions if filter implements Filterable
if filterable, ok := any(filter).(Filterable); ok {
query = filterable.ApplyFilter(query)
}
// Apply pagination if filter implements Pageable
if pageable, ok := any(filter).(Pageable); ok {
offset, limit := pageable.Pagination()
query = query.Offset(offset).Limit(limit)
}
// Apply sorting if filter implements Sortable
if sortable, ok := any(filter).(Sortable); ok {
field, desc := sortable.GetSorting()
if desc {
@@ -54,7 +48,6 @@ func (r *BaseRepository[T, F]) GetAll(ctx context.Context, filter *F) (*[]T, err
return result, nil
}
// GetByID retrieves a single record by ID
func (r *BaseRepository[T, F]) GetByID(ctx context.Context, id interface{}) (*T, error) {
result := new(T)
if err := r.db.WithContext(ctx).Where("id = ?", id).First(result).Error; err != nil {
@@ -66,7 +59,6 @@ func (r *BaseRepository[T, F]) GetByID(ctx context.Context, id interface{}) (*T,
return result, nil
}
// Insert creates a new record
func (r *BaseRepository[T, F]) Insert(ctx context.Context, model *T) error {
if err := r.db.WithContext(ctx).Create(model).Error; err != nil {
return fmt.Errorf("error creating record: %w", err)
@@ -74,7 +66,6 @@ func (r *BaseRepository[T, F]) Insert(ctx context.Context, model *T) error {
return nil
}
// Update modifies an existing record
func (r *BaseRepository[T, F]) Update(ctx context.Context, model *T) error {
if err := r.db.WithContext(ctx).Save(model).Error; err != nil {
return fmt.Errorf("error updating record: %w", err)
@@ -82,7 +73,6 @@ func (r *BaseRepository[T, F]) Update(ctx context.Context, model *T) error {
return nil
}
// Delete removes a record by ID
func (r *BaseRepository[T, F]) Delete(ctx context.Context, id interface{}) error {
if err := r.db.WithContext(ctx).Delete(new(T), id).Error; err != nil {
return fmt.Errorf("error deleting record: %w", err)
@@ -90,7 +80,6 @@ func (r *BaseRepository[T, F]) Delete(ctx context.Context, id interface{}) error
return nil
}
// Count returns the total number of records matching the filter
func (r *BaseRepository[T, F]) Count(ctx context.Context, filter *F) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&r.modelType)
@@ -106,8 +95,6 @@ func (r *BaseRepository[T, F]) Count(ctx context.Context, filter *F) (int64, err
return count, nil
}
// Interfaces for filter capabilities
type Filterable interface {
ApplyFilter(*gorm.DB) *gorm.DB
}

View File

@@ -17,13 +17,11 @@ func NewConfigRepository(db *gorm.DB) *ConfigRepository {
}
}
// UpdateConfig updates or creates a Config record
func (r *ConfigRepository) UpdateConfig(ctx context.Context, config *model.Config) *model.Config {
if err := r.Update(ctx, config); err != nil {
// If update fails, try to insert
if err := r.Insert(ctx, config); err != nil {
return nil
}
}
return config
}
}

View File

@@ -8,20 +8,16 @@ import (
"gorm.io/gorm"
)
// MembershipRepository handles database operations for users, roles, and permissions.
type MembershipRepository struct {
*BaseRepository[model.User, model.MembershipFilter]
}
// NewMembershipRepository creates a new MembershipRepository.
func NewMembershipRepository(db *gorm.DB) *MembershipRepository {
return &MembershipRepository{
BaseRepository: NewBaseRepository[model.User, model.MembershipFilter](db, model.User{}),
}
}
// FindUserByUsername finds a user by their username.
// It preloads the user's role and the role's permissions.
func (r *MembershipRepository) FindUserByUsername(ctx context.Context, username string) (*model.User, error) {
var user model.User
db := r.db.WithContext(ctx)
@@ -32,7 +28,6 @@ func (r *MembershipRepository) FindUserByUsername(ctx context.Context, username
return &user, nil
}
// FindUserByIDWithPermissions finds a user by their ID and preloads Role and Permissions.
func (r *MembershipRepository) FindUserByIDWithPermissions(ctx context.Context, userID string) (*model.User, error) {
var user model.User
db := r.db.WithContext(ctx)
@@ -43,13 +38,11 @@ func (r *MembershipRepository) FindUserByIDWithPermissions(ctx context.Context,
return &user, nil
}
// CreateUser creates a new user.
func (r *MembershipRepository) CreateUser(ctx context.Context, user *model.User) error {
db := r.db.WithContext(ctx)
return db.Create(user).Error
}
// FindRoleByName finds a role by its name.
func (r *MembershipRepository) FindRoleByName(ctx context.Context, name string) (*model.Role, error) {
var role model.Role
db := r.db.WithContext(ctx)
@@ -60,13 +53,11 @@ func (r *MembershipRepository) FindRoleByName(ctx context.Context, name string)
return &role, nil
}
// CreateRole creates a new role.
func (r *MembershipRepository) CreateRole(ctx context.Context, role *model.Role) error {
db := r.db.WithContext(ctx)
return db.Create(role).Error
}
// FindPermissionByName finds a permission by its name.
func (r *MembershipRepository) FindPermissionByName(ctx context.Context, name string) (*model.Permission, error) {
var permission model.Permission
db := r.db.WithContext(ctx)
@@ -77,19 +68,16 @@ func (r *MembershipRepository) FindPermissionByName(ctx context.Context, name st
return &permission, nil
}
// CreatePermission creates a new permission.
func (r *MembershipRepository) CreatePermission(ctx context.Context, permission *model.Permission) error {
db := r.db.WithContext(ctx)
return db.Create(permission).Error
}
// AssignPermissionsToRole assigns a set of permissions to a role.
func (r *MembershipRepository) AssignPermissionsToRole(ctx context.Context, role *model.Role, permissions []model.Permission) error {
db := r.db.WithContext(ctx)
return db.Model(role).Association("Permissions").Replace(permissions)
}
// GetUserPermissions retrieves all permissions for a given user ID.
func (r *MembershipRepository) GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]string, error) {
var user model.User
db := r.db.WithContext(ctx)
@@ -106,7 +94,6 @@ func (r *MembershipRepository) GetUserPermissions(ctx context.Context, userID uu
return permissions, nil
}
// ListUsers retrieves all users.
func (r *MembershipRepository) ListUsers(ctx context.Context) ([]*model.User, error) {
var users []*model.User
db := r.db.WithContext(ctx)
@@ -114,13 +101,11 @@ func (r *MembershipRepository) ListUsers(ctx context.Context) ([]*model.User, er
return users, err
}
// DeleteUser deletes a user.
func (r *MembershipRepository) DeleteUser(ctx context.Context, userID uuid.UUID) error {
db := r.db.WithContext(ctx)
return db.Delete(&model.User{}, "id = ?", userID).Error
}
// FindUserByID finds a user by their ID.
func (r *MembershipRepository) FindUserByID(ctx context.Context, userID uuid.UUID) (*model.User, error) {
var user model.User
db := r.db.WithContext(ctx)
@@ -131,13 +116,11 @@ func (r *MembershipRepository) FindUserByID(ctx context.Context, userID uuid.UUI
return &user, nil
}
// UpdateUser updates a user's details in the database.
func (r *MembershipRepository) UpdateUser(ctx context.Context, user *model.User) error {
db := r.db.WithContext(ctx)
return db.Save(user).Error
}
// FindRoleByID finds a role by its ID.
func (r *MembershipRepository) FindRoleByID(ctx context.Context, roleID uuid.UUID) (*model.Role, error) {
var role model.Role
db := r.db.WithContext(ctx)
@@ -148,12 +131,10 @@ func (r *MembershipRepository) FindRoleByID(ctx context.Context, roleID uuid.UUI
return &role, nil
}
// ListUsersWithFilter retrieves users based on the membership filter.
func (r *MembershipRepository) ListUsersWithFilter(ctx context.Context, filter *model.MembershipFilter) (*[]model.User, error) {
return r.BaseRepository.GetAll(ctx, filter)
}
// ListRoles retrieves all roles.
func (r *MembershipRepository) ListRoles(ctx context.Context) ([]*model.Role, error) {
var roles []*model.Role
db := r.db.WithContext(ctx)

View File

@@ -1,14 +1,16 @@
package repository
import (
"acc-server-manager/local/model"
"acc-server-manager/local/utl/graceful"
"acc-server-manager/local/utl/logging"
"context"
"time"
"go.uber.org/dig"
)
// InitializeRepositories
// Initializes Dependency Injection modules for repositories
//
// Args:
// *dig.Container: Dig Container
// *dig.Container: Dig Container
func InitializeRepositories(c *dig.Container) {
c.Provide(NewServiceControlRepository)
c.Provide(NewStateHistoryRepository)
@@ -17,4 +19,27 @@ func InitializeRepositories(c *dig.Container) {
c.Provide(NewLookupRepository)
c.Provide(NewSteamCredentialsRepository)
c.Provide(NewMembershipRepository)
if err := c.Provide(func() *model.Steam2FAManager {
manager := model.NewSteam2FAManager()
shutdownManager := graceful.GetManager()
shutdownManager.RunGoroutine(func(ctx context.Context) {
ticker := time.NewTicker(15 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
manager.CleanupOldRequests(30 * time.Minute)
}
}
})
return manager
}); err != nil {
logging.Panic("unable to initialize steam 2fa manager")
}
}

View File

@@ -20,10 +20,6 @@ func NewServerRepository(db *gorm.DB) *ServerRepository {
return repo
}
// GetFirstByServiceName
// Gets first row from Server table.
//
// Args:
// context.Context: Application context
// Returns:
// model.ServerModel: Server object from database.

View File

@@ -18,17 +18,14 @@ func NewStateHistoryRepository(db *gorm.DB) *StateHistoryRepository {
}
}
// GetAll retrieves all state history records with the given filter
func (r *StateHistoryRepository) GetAll(ctx context.Context, filter *model.StateHistoryFilter) (*[]model.StateHistory, error) {
return r.BaseRepository.GetAll(ctx, filter)
}
// Insert creates a new state history record
func (r *StateHistoryRepository) Insert(ctx context.Context, model *model.StateHistory) error {
return r.BaseRepository.Insert(ctx, model)
}
// GetLastSessionID gets the last session ID for a server
func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID uuid.UUID) (uuid.UUID, error) {
var lastSession model.StateHistory
result := r.BaseRepository.db.WithContext(ctx).
@@ -38,7 +35,7 @@ func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return uuid.Nil, nil // Return nil UUID if no sessions found
return uuid.Nil, nil
}
return uuid.Nil, result.Error
}
@@ -46,10 +43,8 @@ func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID
return lastSession.SessionID, nil
}
// GetSummaryStats calculates peak players, total sessions, and average players.
func (r *StateHistoryRepository) GetSummaryStats(ctx context.Context, filter *model.StateHistoryFilter) (model.StateHistoryStats, error) {
var stats model.StateHistoryStats
// Parse ServerID to UUID for query
serverUUID, err := uuid.Parse(filter.ServerID)
if err != nil {
return model.StateHistoryStats{}, err
@@ -73,12 +68,10 @@ func (r *StateHistoryRepository) GetSummaryStats(ctx context.Context, filter *mo
return stats, nil
}
// GetTotalPlaytime calculates the total playtime in minutes.
func (r *StateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *model.StateHistoryFilter) (int, error) {
var totalPlaytime struct {
TotalMinutes float64
}
// Parse ServerID to UUID for query
serverUUID, err := uuid.Parse(filter.ServerID)
if err != nil {
return 0, err
@@ -100,10 +93,8 @@ func (r *StateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *m
return int(totalPlaytime.TotalMinutes), nil
}
// GetPlayerCountOverTime gets downsampled player count data.
func (r *StateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, filter *model.StateHistoryFilter) ([]model.PlayerCountPoint, error) {
var points []model.PlayerCountPoint
// Parse ServerID to UUID for query
serverUUID, err := uuid.Parse(filter.ServerID)
if err != nil {
return points, err
@@ -122,10 +113,8 @@ func (r *StateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, fil
return points, err
}
// GetSessionTypes counts sessions by type.
func (r *StateHistoryRepository) GetSessionTypes(ctx context.Context, filter *model.StateHistoryFilter) ([]model.SessionCount, error) {
var sessionTypes []model.SessionCount
// Parse ServerID to UUID for query
serverUUID, err := uuid.Parse(filter.ServerID)
if err != nil {
return sessionTypes, err
@@ -145,10 +134,8 @@ func (r *StateHistoryRepository) GetSessionTypes(ctx context.Context, filter *mo
return sessionTypes, err
}
// GetDailyActivity counts sessions per day.
func (r *StateHistoryRepository) GetDailyActivity(ctx context.Context, filter *model.StateHistoryFilter) ([]model.DailyActivity, error) {
var dailyActivity []model.DailyActivity
// Parse ServerID to UUID for query
serverUUID, err := uuid.Parse(filter.ServerID)
if err != nil {
return dailyActivity, err
@@ -167,10 +154,8 @@ func (r *StateHistoryRepository) GetDailyActivity(ctx context.Context, filter *m
return dailyActivity, err
}
// GetRecentSessions retrieves the 10 most recent sessions.
func (r *StateHistoryRepository) GetRecentSessions(ctx context.Context, filter *model.StateHistoryFilter) ([]model.RecentSession, error) {
var recentSessions []model.RecentSession
// Parse ServerID to UUID for query
serverUUID, err := uuid.Parse(filter.ServerID)
if err != nil {
return recentSessions, err
@@ -187,7 +172,7 @@ func (r *StateHistoryRepository) GetRecentSessions(ctx context.Context, filter *
FROM state_histories
WHERE server_id = ? AND date_created BETWEEN ? AND ?
GROUP BY session_id
HAVING COUNT(*) > 1 AND MAX(player_count) > 0
HAVING MAX(player_count) > 0
ORDER BY date DESC
LIMIT 10
`

View File

@@ -77,8 +77,8 @@ func NewConfigService(repository *repository.ConfigRepository, serverRepository
repository: repository,
serverRepository: serverRepository,
configCache: model.NewServerConfigCache(model.CacheConfig{
ExpirationTime: 5 * time.Minute, // Cache configs for 5 minutes
ThrottleTime: 1 * time.Second, // Prevent rapid re-reads
ExpirationTime: 5 * time.Minute,
ThrottleTime: 1 * time.Second,
DefaultStatus: model.StatusUnknown,
}),
}
@@ -88,10 +88,6 @@ func (as *ConfigService) SetServerService(serverService *ServerService) {
as.serverService = serverService
}
// UpdateConfig
// Updates physical config file and caches it in database.
//
// Args:
// context.Context: Application context
// Returns:
// string: Application version
@@ -103,7 +99,62 @@ func (as *ConfigService) UpdateConfig(ctx *fiber.Ctx, body *map[string]interface
return as.updateConfigInternal(ctx.UserContext(), serverID, configFile, body, override)
}
// updateConfigInternal handles the actual config update logic without Fiber dependencies
func (as *ConfigService) updateConfigFiles(ctx context.Context, server *model.Server, configFile string, body *map[string]interface{}, override bool) ([]byte, []byte, error) {
if server == nil {
logging.Error("Server not found")
return nil, nil, fmt.Errorf("server not found")
}
configPath := filepath.Join(server.GetConfigPath(), configFile)
oldData, err := os.ReadFile(configPath)
if err != nil {
if os.IsNotExist(err) {
dir := filepath.Dir(configPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, nil, err
}
if err := os.WriteFile(configPath, []byte("{}"), 0644); err != nil {
return nil, nil, err
}
oldData = []byte("{}")
} else {
return nil, nil, err
}
}
oldDataUTF8, err := DecodeUTF16LEBOM(oldData)
if err != nil {
return nil, nil, err
}
newData, err := json.Marshal(&body)
if err != nil {
return nil, nil, err
}
if !override {
newData, err = jsons.Merge(oldDataUTF8, newData)
if err != nil {
return nil, nil, err
}
}
newData, err = common.IndentJson(newData)
if err != nil {
return nil, nil, err
}
newDataUTF16, err := EncodeUTF16LEBOM(newData)
if err != nil {
return nil, nil, err
}
if err := os.WriteFile(configPath, newDataUTF16, 0644); err != nil {
return nil, nil, err
}
return oldDataUTF8, newData, nil
}
func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID string, configFile string, body *map[string]interface{}, override bool) (*model.Config, error) {
serverUUID, err := uuid.Parse(serverID)
if err != nil {
@@ -117,63 +168,14 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri
return nil, fmt.Errorf("server not found")
}
// Read existing config
configPath := filepath.Join(server.GetConfigPath(), configFile)
oldData, err := os.ReadFile(configPath)
if err != nil {
if os.IsNotExist(err) {
// Create directory if it doesn't exist
dir := filepath.Dir(configPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, err
}
// Create empty JSON file
if err := os.WriteFile(configPath, []byte("{}"), 0644); err != nil {
return nil, err
}
oldData = []byte("{}")
} else {
return nil, err
}
}
oldDataUTF8, err := DecodeUTF16LEBOM(oldData)
oldDataUTF8, newData, err := as.updateConfigFiles(ctx, server, configFile, body, override)
if err != nil {
return nil, err
}
// Write new config
newData, err := json.Marshal(&body)
if err != nil {
return nil, err
}
if !override {
newData, err = jsons.Merge(oldDataUTF8, newData)
if err != nil {
return nil, err
}
}
newData, err = common.IndentJson(newData)
if err != nil {
return nil, err
}
newDataUTF16, err := EncodeUTF16LEBOM(newData)
if err != nil {
return nil, err
}
if err := os.WriteFile(configPath, newDataUTF16, 0644); err != nil {
return nil, err
}
// Invalidate all configs for this server since configs can be interdependent
as.configCache.InvalidateServerCache(serverID)
as.serverService.StartAccServerRuntime(server)
// Log change
return as.repository.UpdateConfig(ctx, &model.Config{
ServerID: serverUUID,
ConfigFile: configFile,
@@ -183,10 +185,6 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri
}), nil
}
// GetConfig
// Gets physical config file and caches it in database.
//
// Args:
// context.Context: Application context
// Returns:
// string: Application version
@@ -197,44 +195,47 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) {
logging.Debug("Getting config for server ID: %s, file: %s", serverIDStr, configFile)
server, err := as.serverRepository.GetByID(ctx.UserContext(), serverIDStr)
if err != nil {
logging.Error("Server not found")
return nil, fiber.NewError(404, "Server not found")
}
return as.getConfigFile(server, configFile)
}
// Try to get from cache based on config file type
func (as *ConfigService) getConfigFile(server *model.Server, configFile string) (interface{}, error) {
serverIDStr := server.ID.String()
switch configFile {
case ConfigurationJson:
if cached, ok := as.configCache.GetConfiguration(serverIDStr); ok {
logging.Debug("Returning cached configuration for server ID: %s", serverIDStr)
return cached, nil
return *cached, nil
}
case AssistRulesJson:
if cached, ok := as.configCache.GetAssistRules(serverIDStr); ok {
logging.Debug("Returning cached assist rules for server ID: %s", serverIDStr)
return cached, nil
return *cached, nil
}
case EventJson:
if cached, ok := as.configCache.GetEvent(serverIDStr); ok {
logging.Debug("Returning cached event config for server ID: %s", serverIDStr)
return cached, nil
return *cached, nil
}
case EventRulesJson:
if cached, ok := as.configCache.GetEventRules(serverIDStr); ok {
logging.Debug("Returning cached event rules for server ID: %s", serverIDStr)
return cached, nil
return *cached, nil
}
case SettingsJson:
if cached, ok := as.configCache.GetSettings(serverIDStr); ok {
logging.Debug("Returning cached settings for server ID: %s", serverIDStr)
return cached, nil
return *cached, nil
}
}
logging.Debug("Cache miss for server ID: %s, file: %s - loading from disk", serverIDStr, configFile)
// Not in cache, load from disk
configPath := filepath.Join(server.GetConfigPath(), configFile)
configPath := server.GetConfigPath()
decoder := DecodeFileName(configFile)
if decoder == nil {
return nil, errors.New("invalid config file")
@@ -244,43 +245,39 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) {
if err != nil {
if os.IsNotExist(err) {
logging.Debug("Config file not found, creating default for server ID: %s, file: %s", serverIDStr, configFile)
// Return empty config if file doesn't exist
switch configFile {
case ConfigurationJson:
return &model.Configuration{}, nil
return model.Configuration{}, nil
case AssistRulesJson:
return &model.AssistRules{}, nil
return model.AssistRules{}, nil
case EventJson:
return &model.EventConfig{}, nil
return model.EventConfig{}, nil
case EventRulesJson:
return &model.EventRules{}, nil
return model.EventRules{}, nil
case SettingsJson:
return &model.ServerSettings{}, nil
return model.ServerSettings{}, nil
}
}
return nil, err
}
// Cache the loaded config
switch configFile {
case ConfigurationJson:
as.configCache.UpdateConfiguration(serverIDStr, *config.(*model.Configuration))
as.configCache.UpdateConfiguration(serverIDStr, config.(model.Configuration))
case AssistRulesJson:
as.configCache.UpdateAssistRules(serverIDStr, *config.(*model.AssistRules))
as.configCache.UpdateAssistRules(serverIDStr, config.(model.AssistRules))
case EventJson:
as.configCache.UpdateEvent(serverIDStr, *config.(*model.EventConfig))
as.configCache.UpdateEvent(serverIDStr, config.(model.EventConfig))
case EventRulesJson:
as.configCache.UpdateEventRules(serverIDStr, *config.(*model.EventRules))
as.configCache.UpdateEventRules(serverIDStr, config.(model.EventRules))
case SettingsJson:
as.configCache.UpdateSettings(serverIDStr, *config.(*model.ServerSettings))
as.configCache.UpdateSettings(serverIDStr, config.(model.ServerSettings))
}
logging.Debug("Successfully loaded and cached config for server ID: %s, file: %s", serverIDStr, configFile)
return config, nil
}
// GetConfigs
// Gets all configurations for a server, using cache when possible.
func (as *ConfigService) GetConfigs(ctx *fiber.Ctx) (*model.Configurations, error) {
serverID := ctx.Params("id")
@@ -296,82 +293,33 @@ func (as *ConfigService) GetConfigs(ctx *fiber.Ctx) (*model.Configurations, erro
func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configurations, error) {
serverIDStr := server.ID.String()
logging.Info("Loading configs for server ID: %s at path: %s", serverIDStr, server.GetConfigPath())
configs := &model.Configurations{}
// Load configuration
if cached, ok := as.configCache.GetConfiguration(serverIDStr); ok {
logging.Debug("Using cached configuration for server %s", serverIDStr)
configs.Configuration = *cached
} else {
logging.Debug("Loading configuration from disk for server %s", serverIDStr)
config, err := mustDecode[model.Configuration](ConfigurationJson, server.GetConfigPath())
if err != nil {
logging.Error("Failed to load configuration for server %s: %v", serverIDStr, err)
return nil, fmt.Errorf("failed to load configuration: %v", err)
}
configs.Configuration = config
as.configCache.UpdateConfiguration(serverIDStr, config)
settingsConf, err := as.getConfigFile(server, SettingsJson)
if err != nil {
return nil, err
}
// Load assist rules
if cached, ok := as.configCache.GetAssistRules(serverIDStr); ok {
logging.Debug("Using cached assist rules for server %s", serverIDStr)
configs.AssistRules = *cached
} else {
logging.Debug("Loading assist rules from disk for server %s", serverIDStr)
rules, err := mustDecode[model.AssistRules](AssistRulesJson, server.GetConfigPath())
if err != nil {
logging.Error("Failed to load assist rules for server %s: %v", serverIDStr, err)
return nil, fmt.Errorf("failed to load assist rules: %v", err)
}
configs.AssistRules = rules
as.configCache.UpdateAssistRules(serverIDStr, rules)
eventRulesConf, err := as.getConfigFile(server, EventRulesJson)
if err != nil {
return nil, err
}
// Load event config
if cached, ok := as.configCache.GetEvent(serverIDStr); ok {
logging.Debug("Using cached event config for server %s", serverIDStr)
configs.Event = *cached
} else {
logging.Debug("Loading event config from disk for server %s", serverIDStr)
event, err := mustDecode[model.EventConfig](EventJson, server.GetConfigPath())
if err != nil {
logging.Error("Failed to load event config for server %s: %v", serverIDStr, err)
return nil, fmt.Errorf("failed to load event config: %v", err)
}
configs.Event = event
logging.Debug("Updating event config for server %s with track: %s", serverIDStr, event.Track)
as.configCache.UpdateEvent(serverIDStr, event)
eventConf, err := as.getConfigFile(server, EventJson)
if err != nil {
return nil, err
}
// Load event rules
if cached, ok := as.configCache.GetEventRules(serverIDStr); ok {
logging.Debug("Using cached event rules for server %s", serverIDStr)
configs.EventRules = *cached
} else {
logging.Debug("Loading event rules from disk for server %s", serverIDStr)
rules, err := mustDecode[model.EventRules](EventRulesJson, server.GetConfigPath())
if err != nil {
logging.Error("Failed to load event rules for server %s: %v", serverIDStr, err)
return nil, fmt.Errorf("failed to load event rules: %v", err)
}
configs.EventRules = rules
as.configCache.UpdateEventRules(serverIDStr, rules)
assistRulesConf, err := as.getConfigFile(server, AssistRulesJson)
if err != nil {
return nil, err
}
// Load settings
if cached, ok := as.configCache.GetSettings(serverIDStr); ok {
logging.Debug("Using cached settings for server %s", serverIDStr)
configs.Settings = *cached
} else {
logging.Debug("Loading settings from disk for server %s", serverIDStr)
settings, err := mustDecode[model.ServerSettings](SettingsJson, server.GetConfigPath())
if err != nil {
logging.Error("Failed to load settings for server %s: %v", serverIDStr, err)
return nil, fmt.Errorf("failed to load settings: %v", err)
}
configs.Settings = settings
as.configCache.UpdateSettings(serverIDStr, settings)
configurationConf, err := as.getConfigFile(server, ConfigurationJson)
if err != nil {
return nil, err
}
configs := &model.Configurations{
Settings: settingsConf.(model.ServerSettings),
EventRules: eventRulesConf.(model.EventRules),
Event: eventConf.(model.EventConfig),
AssistRules: assistRulesConf.(model.AssistRules),
Configuration: configurationConf.(model.Configuration),
}
logging.Info("Successfully loaded all configs for server %s", serverIDStr)
@@ -396,9 +344,6 @@ func readFile(path string, configFile string) ([]byte, error) {
configPath := filepath.Join(path, configFile)
oldData, err := os.ReadFile(configPath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("config file %s does not exist at %s", configFile, configPath)
}
return nil, err
}
return oldData, nil
@@ -475,9 +420,7 @@ func (as *ConfigService) GetConfiguration(server *model.Server) (*model.Configur
return &config, nil
}
// SaveConfiguration saves the configuration for a server
func (as *ConfigService) SaveConfiguration(server *model.Server, config *model.Configuration) error {
// Convert config to map for UpdateConfig
configMap := make(map[string]interface{})
configBytes, err := json.Marshal(config)
if err != nil {
@@ -487,7 +430,6 @@ func (as *ConfigService) SaveConfiguration(server *model.Server, config *model.C
return fmt.Errorf("failed to unmarshal configuration: %v", err)
}
// Update the configuration using the internal method
_, err = as.updateConfigInternal(context.Background(), server.ID.String(), ConfigurationJson, &configMap, true)
_, _, err = as.updateConfigFiles(context.Background(), server, ConfigurationJson, &configMap, true)
return err
}

View File

@@ -96,11 +96,9 @@ func (s *FirewallService) DeleteServerRules(serverName string, tcpPorts, udpPort
}
func (s *FirewallService) UpdateServerRules(serverName string, tcpPorts, udpPorts []int) error {
// First delete existing rules
if err := s.DeleteServerRules(serverName, tcpPorts, udpPorts); err != nil {
return err
}
// Then create new rules
return s.CreateServerRules(serverName, tcpPorts, udpPorts)
}
}

View File

@@ -12,47 +12,57 @@ import (
"github.com/google/uuid"
)
// CacheInvalidator interface for cache invalidation
type CacheInvalidator interface {
InvalidateUserPermissions(userID string)
InvalidateAllUserPermissions()
}
// MembershipService provides business logic for membership-related operations.
type MembershipService struct {
repo *repository.MembershipRepository
cacheInvalidator CacheInvalidator
jwtHandler *jwt.JWTHandler
openJwtHandler *jwt.OpenJWTHandler
}
// NewMembershipService creates a new MembershipService.
func NewMembershipService(repo *repository.MembershipRepository) *MembershipService {
func NewMembershipService(repo *repository.MembershipRepository, jwtHandler *jwt.JWTHandler, openJwtHandler *jwt.OpenJWTHandler) *MembershipService {
return &MembershipService{
repo: repo,
cacheInvalidator: nil, // Will be set later via SetCacheInvalidator
cacheInvalidator: nil,
jwtHandler: jwtHandler,
openJwtHandler: openJwtHandler,
}
}
// SetCacheInvalidator sets the cache invalidator after service initialization
func (s *MembershipService) SetCacheInvalidator(invalidator CacheInvalidator) {
s.cacheInvalidator = invalidator
}
// Login authenticates a user and returns a JWT.
func (s *MembershipService) Login(ctx context.Context, username, password string) (string, error) {
func (s *MembershipService) HandleLogin(ctx context.Context, username, password string) (*model.User, error) {
user, err := s.repo.FindUserByUsername(ctx, username)
if err != nil {
return "", errors.New("invalid credentials")
return nil, errors.New("invalid credentials")
}
// Use secure password verification with constant-time comparison
if err := user.VerifyPassword(password); err != nil {
return "", errors.New("invalid credentials")
return nil, errors.New("invalid credentials")
}
return jwt.GenerateToken(user)
return user, nil
}
func (s *MembershipService) Login(ctx context.Context, username, password string) (string, error) {
user, err := s.HandleLogin(ctx, username, password)
if err != nil {
return "", err
}
return s.jwtHandler.GenerateToken(user.ID.String())
}
func (s *MembershipService) GenerateOpenToken(ctx context.Context, userId string) (string, error) {
return s.openJwtHandler.GenerateToken(userId)
}
// CreateUser creates a new user.
func (s *MembershipService) CreateUser(ctx context.Context, username, password, roleName string) (*model.User, error) {
role, err := s.repo.FindRoleByName(ctx, roleName)
@@ -76,43 +86,35 @@ func (s *MembershipService) CreateUser(ctx context.Context, username, password,
return user, nil
}
// ListUsers retrieves all users.
func (s *MembershipService) ListUsers(ctx context.Context) ([]*model.User, error) {
return s.repo.ListUsers(ctx)
}
// GetUser retrieves a single user by ID.
func (s *MembershipService) GetUser(ctx context.Context, userID uuid.UUID) (*model.User, error) {
return s.repo.FindUserByID(ctx, userID)
}
// GetUserWithPermissions retrieves a single user by ID with their role and permissions.
func (s *MembershipService) GetUserWithPermissions(ctx context.Context, userID string) (*model.User, error) {
return s.repo.FindUserByIDWithPermissions(ctx, userID)
}
// UpdateUserRequest defines the request body for updating a user.
type UpdateUserRequest struct {
Username *string `json:"username"`
Password *string `json:"password"`
RoleID *uuid.UUID `json:"roleId"`
}
// DeleteUser deletes a user with validation to prevent Super Admin deletion.
func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) error {
// Get user with role information
user, err := s.repo.FindUserByID(ctx, userID)
if err != nil {
return errors.New("user not found")
}
// Get role to check if it's Super Admin
role, err := s.repo.FindRoleByID(ctx, user.RoleID)
if err != nil {
return errors.New("user role not found")
}
// Prevent deletion of Super Admin users
if role.Name == "Super Admin" {
return errors.New("cannot delete Super Admin user")
}
@@ -122,7 +124,6 @@ func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) er
return err
}
// Invalidate cache for deleted user
if s.cacheInvalidator != nil {
s.cacheInvalidator.InvalidateUserPermissions(userID.String())
}
@@ -131,7 +132,6 @@ func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) er
return nil
}
// UpdateUser updates a user's details.
func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, req UpdateUserRequest) (*model.User, error) {
user, err := s.repo.FindUserByID(ctx, userID)
if err != nil {
@@ -143,12 +143,10 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re
}
if req.Password != nil && *req.Password != "" {
// Password will be automatically hashed in BeforeUpdate hook
user.Password = *req.Password
}
if req.RoleID != nil {
// Check if role exists
_, err := s.repo.FindRoleByID(ctx, *req.RoleID)
if err != nil {
return nil, errors.New("role not found")
@@ -160,7 +158,6 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re
return nil, err
}
// Invalidate cache if role was changed
if req.RoleID != nil && s.cacheInvalidator != nil {
s.cacheInvalidator.InvalidateUserPermissions(userID.String())
}
@@ -169,14 +166,12 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re
return user, nil
}
// HasPermission checks if a user has a specific permission.
func (s *MembershipService) HasPermission(ctx context.Context, userID string, permissionName string) (bool, error) {
user, err := s.repo.FindUserByIDWithPermissions(ctx, userID)
if err != nil {
return false, err
}
// Super admin and Admin have all permissions
if user.Role.Name == "Super Admin" || user.Role.Name == "Admin" {
return true, nil
}
@@ -190,15 +185,13 @@ func (s *MembershipService) HasPermission(ctx context.Context, userID string, pe
return false, nil
}
// SetupInitialData creates the initial roles and permissions.
func (s *MembershipService) SetupInitialData(ctx context.Context) error {
// Define all permissions
permissions := model.AllPermissions()
createdPermissions := make([]model.Permission, 0)
for _, pName := range permissions {
perm, err := s.repo.FindPermissionByName(ctx, pName)
if err != nil { // Assuming error means not found
if err != nil {
perm = &model.Permission{Name: pName}
if err := s.repo.CreatePermission(ctx, perm); err != nil {
return err
@@ -207,7 +200,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
createdPermissions = append(createdPermissions, *perm)
}
// Create Super Admin role with all permissions
superAdminRole, err := s.repo.FindRoleByName(ctx, "Super Admin")
if err != nil {
superAdminRole = &model.Role{Name: "Super Admin"}
@@ -219,7 +211,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
return err
}
// Create Admin role with same permissions as Super Admin
adminRole, err := s.repo.FindRoleByName(ctx, "Admin")
if err != nil {
adminRole = &model.Role{Name: "Admin"}
@@ -231,7 +222,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
return err
}
// Create Manager role with limited permissions (excluding membership, role, user, server create/delete)
managerRole, err := s.repo.FindRoleByName(ctx, "Manager")
if err != nil {
managerRole = &model.Role{Name: "Manager"}
@@ -240,7 +230,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
}
}
// Define manager permissions (limited set)
managerPermissionNames := []string{
model.ServerView,
model.ServerUpdate,
@@ -264,16 +253,14 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
return err
}
// Invalidate all caches after role setup changes
if s.cacheInvalidator != nil {
s.cacheInvalidator.InvalidateAllUserPermissions()
}
// Create a default admin user if one doesn't exist
_, err = s.repo.FindUserByUsername(ctx, "admin")
if err != nil {
logging.Debug("Creating default admin user")
_, err = s.CreateUser(ctx, "admin", os.Getenv("PASSWORD"), "Super Admin") // Default password, should be changed
_, err = s.CreateUser(ctx, "admin", os.Getenv("PASSWORD"), "Super Admin")
if err != nil {
return err
}
@@ -282,7 +269,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
return nil
}
// GetAllRoles retrieves all roles for dropdown selection.
func (s *MembershipService) GetAllRoles(ctx context.Context) ([]*model.Role, error) {
return s.repo.ListRoles(ctx)
}

View File

@@ -20,7 +20,7 @@ import (
const (
DefaultStartPort = 9600
RequiredPortCount = 1 // Update this if ACC needs more ports
RequiredPortCount = 1
)
type ServerService struct {
@@ -31,6 +31,7 @@ type ServerService struct {
steamService *SteamService
windowsService *WindowsService
firewallService *FirewallService
webSocketService *WebSocketService
instances sync.Map // Track instances per server
lastInsertTimes sync.Map // Track last insert time per server
debouncers sync.Map // Track debounce timers per server
@@ -44,18 +45,15 @@ type pendingState struct {
}
func (s *ServerService) ensureLogTailing(server *model.Server, instance *tracking.AccServerInstance) {
// Check if we already have a tailer
if _, exists := s.logTailers.Load(server.ID); exists {
return
}
// Start tailing in a goroutine that handles file creation/deletion
go func() {
logPath := filepath.Join(server.GetLogPath(), "server.log")
tailer := tracking.NewLogTailer(logPath, instance.HandleLogLine)
s.logTailers.Store(server.ID, tailer)
// Start tailing and automatically handle file changes
tailer.Start()
}()
}
@@ -68,6 +66,7 @@ func NewServerService(
steamService *SteamService,
windowsService *WindowsService,
firewallService *FirewallService,
webSocketService *WebSocketService,
) *ServerService {
service := &ServerService{
repository: repository,
@@ -77,9 +76,9 @@ func NewServerService(
steamService: steamService,
windowsService: windowsService,
firewallService: firewallService,
webSocketService: webSocketService,
}
// Initialize server instances
servers, err := repository.GetAll(context.Background(), &model.ServerFilter{})
if err != nil {
logging.Error("Failed to get servers: %v", err)
@@ -87,7 +86,6 @@ func NewServerService(
}
for i := range *servers {
// Initialize instance regardless of status
logging.Info("Starting server runtime for server ID: %d", (*servers)[i].ID)
service.StartAccServerRuntime(&(*servers)[i])
}
@@ -96,7 +94,7 @@ func NewServerService(
}
func (s *ServerService) shouldInsertStateHistory(serverID uuid.UUID) bool {
insertInterval := 5 * time.Minute // Configure this as needed
insertInterval := 5 * time.Minute
lastInsertInterface, exists := s.lastInsertTimes.Load(serverID)
if !exists {
@@ -119,16 +117,15 @@ func (s *ServerService) getNextSessionID(serverID uuid.UUID) uuid.UUID {
lastID, err := s.stateHistoryRepo.GetLastSessionID(context.Background(), serverID)
if err != nil {
logging.Error("Failed to get last session ID for server %s: %v", serverID, err)
return uuid.New() // Return new UUID as fallback
return uuid.New()
}
if lastID == uuid.Nil {
return uuid.New() // Return new UUID if no previous session
return uuid.New()
}
return uuid.New() // Always generate new UUID for each session
return uuid.New()
}
func (s *ServerService) insertStateHistory(serverID uuid.UUID, state *model.ServerState) {
// Get or create session ID when session changes
currentSessionInterface, exists := s.instances.Load(serverID)
var sessionID uuid.UUID
if !exists {
@@ -159,8 +156,7 @@ func (s *ServerService) insertStateHistory(serverID uuid.UUID, state *model.Serv
})
}
func (s *ServerService) updateSessionDuration(server *model.Server, sessionType string) {
// Get configs using helper methods
func (s *ServerService) updateSessionDuration(server *model.Server, sessionType model.TrackSession) {
event, err := s.configService.GetEventConfig(server)
if err != nil {
event = &model.EventConfig{}
@@ -178,9 +174,7 @@ func (s *ServerService) updateSessionDuration(server *model.Server, sessionType
serverInstance.State.Track = event.Track
serverInstance.State.MaxConnections = configuration.MaxConnections.ToInt()
// Check if session type has changed
if serverInstance.State.Session != sessionType {
// Get new session ID for the new session
sessionID := s.getNextSessionID(server.ID)
s.sessionIDs.Store(server.ID, sessionID)
}
@@ -201,32 +195,27 @@ func (s *ServerService) updateSessionDuration(server *model.Server, sessionType
}
func (s *ServerService) GenerateServerPath(server *model.Server) {
// Get the base steamcmd path from environment variable
steamCMDPath := env.GetSteamCMDDirPath()
server.Path = server.GenerateServerPath(steamCMDPath)
server.FromSteamCMD = true
}
func (s *ServerService) handleStateChange(server *model.Server, state *model.ServerState) {
// Update session duration when session changes
s.updateSessionDuration(server, state.Session)
// Invalidate status cache when server state changes
s.apiService.statusCache.InvalidateStatus(server.ServiceName)
// Cancel existing timer if any
if debouncer, exists := s.debouncers.Load(server.ID); exists {
pending := debouncer.(*pendingState)
pending.timer.Stop()
}
// Create new timer
timer := time.NewTimer(5 * time.Minute)
s.debouncers.Store(server.ID, &pendingState{
timer: timer,
state: state,
})
// Start goroutine to handle the delayed insert
go func() {
<-timer.C
if debouncer, exists := s.debouncers.Load(server.ID); exists {
@@ -236,14 +225,12 @@ func (s *ServerService) handleStateChange(server *model.Server, state *model.Ser
}
}()
// If enough time has passed since last insert, insert immediately
if s.shouldInsertStateHistory(server.ID) {
s.insertStateHistory(server.ID, state)
}
}
func (s *ServerService) StartAccServerRuntime(server *model.Server) {
// Get or create instance
instanceInterface, exists := s.instances.Load(server.ID)
var instance *tracking.AccServerInstance
if !exists {
@@ -255,20 +242,14 @@ func (s *ServerService) StartAccServerRuntime(server *model.Server) {
instance = instanceInterface.(*tracking.AccServerInstance)
}
// Invalidate config cache for this server before loading new configs
serverIDStr := server.ID.String()
s.configService.configCache.InvalidateServerCache(serverIDStr)
s.updateSessionDuration(server, instance.State.Session)
// Ensure log tailing is running (regardless of server status)
s.ensureLogTailing(server, instance)
}
// GetAll
// Gets All rows from Server table.
//
// Args:
// context.Context: Application context
// Returns:
// string: Application version
@@ -300,10 +281,6 @@ func (s *ServerService) GetAll(ctx *fiber.Ctx, filter *model.ServerFilter) (*[]m
return servers, nil
}
// GetById
// Gets rows by ID from Server table.
//
// Args:
// context.Context: Application context
// Returns:
// string: Application version
@@ -330,76 +307,188 @@ func (as *ServerService) GetById(ctx *fiber.Ctx, serverID uuid.UUID) (*model.Ser
return server, nil
}
func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error {
// Validate basic server configuration
func (s *ServerService) CreateServerAsync(ctx *fiber.Ctx, server *model.Server) error {
logging.Info("create server start")
if err := server.Validate(); err != nil {
logging.Info("create server validation failed")
return err
}
// Install server using SteamCMD
if err := s.steamService.InstallServer(ctx.UserContext(), server.GetServerPath()); err != nil {
return fmt.Errorf("failed to install server: %v", err)
}
s.GenerateServerPath(server)
// Create Windows service with correct paths
execPath := filepath.Join(server.GetServerPath(), "accServer.exe")
serverWorkingDir := filepath.Join(server.GetServerPath(), "server")
if err := s.windowsService.CreateService(ctx.UserContext(), server.ServiceName, execPath, serverWorkingDir, nil); err != nil {
// Cleanup on failure
s.steamService.UninstallServer(server.Path)
return fmt.Errorf("failed to create Windows service: %v", err)
}
bgCtx := context.Background()
s.configureFirewall(server)
ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount)
if err != nil {
return fmt.Errorf("failed to find available ports: %v", err)
}
// Use the first port for both TCP and UDP
serverPort := ports[0]
tcpPorts := []int{serverPort}
udpPorts := []int{serverPort}
if err := s.firewallService.CreateServerRules(server.ServiceName, tcpPorts, udpPorts); err != nil {
// Cleanup on failure
s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName)
s.steamService.UninstallServer(server.Path)
return fmt.Errorf("failed to create firewall rules: %v", err)
}
// Update server configuration with the allocated port
if err := s.updateServerPort(server, serverPort); err != nil {
return fmt.Errorf("failed to update server configuration: %v", err)
}
// Insert server into database
if err := s.repository.Insert(ctx.UserContext(), server); err != nil {
// Cleanup on failure
s.firewallService.DeleteServerRules(server.ServiceName, tcpPorts, udpPorts)
s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName)
s.steamService.UninstallServer(server.Path)
return fmt.Errorf("failed to insert server into database: %v", err)
}
// Initialize server runtime
s.StartAccServerRuntime(server)
go func() {
logging.Info("create server start background")
if err := s.createServerBackground(bgCtx, server); err != nil {
logging.Error("Async server creation failed for server %s: %v", server.ID, err)
s.webSocketService.BroadcastError(server.ID, "Server creation failed", err.Error())
s.webSocketService.BroadcastComplete(server.ID, false, fmt.Sprintf("Server creation failed: %v", err))
}
}()
return nil
}
type createServerStep struct {
stepType model.ServerCreationStep
important bool
callback func() (string, error)
description string
}
func (s *ServerService) createServerBackground(ctx context.Context, server *model.Server) error {
var serverPort int
var tcpPorts, udpPorts []int
steps := []createServerStep{
{
stepType: model.StepValidation,
important: true,
description: "Server configuration validated successfully",
callback: func() (string, error) {
if err := server.Validate(); err != nil {
return "", fmt.Errorf("validation failed: %v", err)
}
return "Server configuration validated successfully", nil
},
},
{
stepType: model.StepDirectoryCreation,
important: true,
description: "Server directories prepared",
callback: func() (string, error) {
return "Server directories prepared", nil
},
},
{
stepType: model.StepSteamDownload,
important: true,
description: "Server files downloaded successfully",
callback: func() (string, error) {
if err := s.steamService.InstallServerWithWebSocket(ctx, server.Path, &server.ID, s.webSocketService); err != nil {
return "", fmt.Errorf("failed to install server: %v", err)
}
return "Server files downloaded successfully", nil
},
},
{
stepType: model.StepConfigGeneration,
important: true,
description: "",
callback: func() (string, error) {
ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount)
if err != nil {
return "", fmt.Errorf("failed to find available ports: %v", err)
}
serverPort = ports[0]
if err := s.updateServerPort(server, serverPort); err != nil {
return "", fmt.Errorf("failed to update server configuration: %v", err)
}
return fmt.Sprintf("Server configuration generated (Port: %d)", serverPort), nil
},
},
{
stepType: model.StepServiceCreation,
important: true,
description: "",
callback: func() (string, error) {
execPath := filepath.Join(server.GetServerPath(), "accServer.exe")
serverWorkingDir := filepath.Join(server.GetServerPath(), "server")
if err := s.windowsService.CreateService(ctx, server.ServiceName, execPath, serverWorkingDir, nil); err != nil {
return "", fmt.Errorf("failed to create Windows service: %v", err)
}
return fmt.Sprintf("Windows service '%s' created successfully", server.ServiceName), nil
},
},
{
stepType: model.StepFirewallRules,
important: false,
description: "",
callback: func() (string, error) {
s.configureFirewall(server)
tcpPorts = []int{serverPort}
udpPorts = []int{serverPort}
if err := s.firewallService.CreateServerRules(server.ServiceName, tcpPorts, udpPorts); err != nil {
return "", fmt.Errorf("failed to create firewall rules: %v", err)
}
return fmt.Sprintf("Firewall rules created for port %d", serverPort), nil
},
},
{
stepType: model.StepDatabaseSave,
important: true,
description: "Server saved to database successfully",
callback: func() (string, error) {
if err := s.repository.Insert(ctx, server); err != nil {
return "", fmt.Errorf("failed to insert server into database: %v", err)
}
return "Server saved to database successfully", nil
},
},
}
for i, step := range steps {
s.webSocketService.BroadcastStep(server.ID, step.stepType, model.StatusInProgress,
model.GetStepDescription(step.stepType), "")
successMessage, err := step.callback()
if err != nil {
s.webSocketService.BroadcastStep(server.ID, step.stepType, model.StatusFailed,
"", err.Error())
if step.important {
s.rollbackSteps(ctx, server, steps[:i], tcpPorts, udpPorts)
return err
}
}
s.webSocketService.BroadcastStep(server.ID, step.stepType, model.StatusCompleted,
successMessage, "")
}
s.StartAccServerRuntime(server)
s.webSocketService.BroadcastStep(server.ID, model.StepCompleted, model.StatusCompleted,
model.GetStepDescription(model.StepCompleted), "")
s.webSocketService.BroadcastComplete(server.ID, true,
fmt.Sprintf("Server '%s' created successfully on port %d", server.Name, serverPort))
return nil
}
func (s *ServerService) rollbackSteps(ctx context.Context, server *model.Server, completedSteps []createServerStep, tcpPorts, udpPorts []int) {
for i := len(completedSteps) - 1; i >= 0; i-- {
step := completedSteps[i]
switch step.stepType {
case model.StepDatabaseSave:
s.repository.Delete(ctx, server.ID)
case model.StepFirewallRules:
if len(tcpPorts) > 0 && len(udpPorts) > 0 {
s.firewallService.DeleteServerRules(server.ServiceName, tcpPorts, udpPorts)
}
case model.StepServiceCreation:
s.windowsService.DeleteService(ctx, server.ServiceName)
case model.StepSteamDownload:
s.steamService.UninstallServer(server.Path)
}
}
}
func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error {
// Get server details
server, err := s.repository.GetByID(ctx.UserContext(), serverID)
if err != nil {
return fmt.Errorf("failed to get server details: %v", err)
}
// Stop and remove Windows service
if err := s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName); err != nil {
logging.Error("Failed to delete Windows service: %v", err)
}
// Remove firewall rules
configuration, err := s.configService.GetConfiguration(server)
if err != nil {
logging.Error("Failed to get configuration for server %d: %v", server.ID, err)
@@ -410,17 +499,14 @@ func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error {
logging.Error("Failed to delete firewall rules: %v", err)
}
// Uninstall server files
if err := s.steamService.UninstallServer(server.Path); err != nil {
logging.Error("Failed to uninstall server: %v", err)
}
// Remove from database
if err := s.repository.Delete(ctx.UserContext(), serverID); err != nil {
return fmt.Errorf("failed to delete server from database: %v", err)
}
// Cleanup runtime resources
if tailer, exists := s.logTailers.Load(server.ID); exists {
tailer.(*tracking.LogTailer).Stop()
s.logTailers.Delete(server.ID)
@@ -430,84 +516,27 @@ func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error {
s.debouncers.Delete(server.ID)
s.sessionIDs.Delete(server.ID)
// Invalidate status cache for deleted server
s.apiService.statusCache.InvalidateStatus(server.ServiceName)
return nil
}
func (s *ServerService) UpdateServer(ctx *fiber.Ctx, server *model.Server) error {
// Validate server configuration
if err := server.Validate(); err != nil {
return err
}
// Get existing server details
existingServer, err := s.repository.GetByID(ctx.UserContext(), server.ID)
if err != nil {
return fmt.Errorf("failed to get existing server details: %v", err)
}
// Update server files if path changed
if existingServer.Path != server.Path {
if err := s.steamService.InstallServer(ctx.UserContext(), server.Path); err != nil {
return fmt.Errorf("failed to install server to new location: %v", err)
}
// Clean up old installation
if err := s.steamService.UninstallServer(existingServer.Path); err != nil {
logging.Error("Failed to remove old server installation: %v", err)
}
}
// Update Windows service if necessary
if existingServer.ServiceName != server.ServiceName || existingServer.Path != server.Path {
execPath := filepath.Join(server.GetServerPath(), "accServer.exe")
serverWorkingDir := server.GetServerPath()
if err := s.windowsService.UpdateService(ctx.UserContext(), server.ServiceName, execPath, serverWorkingDir, nil); err != nil {
return fmt.Errorf("failed to update Windows service: %v", err)
}
}
// Update firewall rules if service name changed
if existingServer.ServiceName != server.ServiceName {
if err := s.configureFirewall(server); err != nil {
return fmt.Errorf("failed to update firewall rules: %v", err)
}
// Invalidate cache for old service name
s.apiService.statusCache.InvalidateStatus(existingServer.ServiceName)
}
// Update database record
if err := s.repository.Update(ctx.UserContext(), server); err != nil {
return fmt.Errorf("failed to update server in database: %v", err)
}
// Restart server runtime
s.StartAccServerRuntime(server)
return nil
}
func (s *ServerService) configureFirewall(server *model.Server) error {
// Find available ports for the server
ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount)
if err != nil {
return fmt.Errorf("failed to find available ports: %v", err)
}
// Use the first port for both TCP and UDP
serverPort := ports[0]
tcpPorts := []int{serverPort}
udpPorts := []int{serverPort}
logging.Info("Configuring firewall for server %d with port %d", server.ID, serverPort)
// Configure firewall rules
if err := s.firewallService.UpdateServerRules(server.Name, tcpPorts, udpPorts); err != nil {
return fmt.Errorf("failed to configure firewall: %v", err)
}
// Update server configuration with the allocated port
if err := s.updateServerPort(server, serverPort); err != nil {
return fmt.Errorf("failed to update server configuration: %v", err)
}
@@ -516,7 +545,6 @@ func (s *ServerService) configureFirewall(server *model.Server) error {
}
func (s *ServerService) updateServerPort(server *model.Server, port int) error {
// Load current configuration
config, err := s.configService.GetConfiguration(server)
if err != nil {
return fmt.Errorf("failed to load server configuration: %v", err)
@@ -525,7 +553,6 @@ func (s *ServerService) updateServerPort(server *model.Server, port int) error {
config.TcpPort = model.IntString(port)
config.UdpPort = model.IntString(port)
// Save the updated configuration
if err := s.configService.SaveConfiguration(server, config); err != nil {
return fmt.Errorf("failed to save server configuration: %v", err)
}

View File

@@ -7,26 +7,22 @@ import (
"go.uber.org/dig"
)
// InitializeServices
// Initializes Dependency Injection modules for services
//
// Args:
// *dig.Container: Dig Container
// *dig.Container: Dig Container
func InitializeServices(c *dig.Container) {
logging.Debug("Initializing repositories")
repository.InitializeRepositories(c)
logging.Debug("Registering services")
// Provide services
c.Provide(NewSteamService)
c.Provide(NewServerService)
c.Provide(NewStateHistoryService)
c.Provide(NewServiceControlService)
c.Provide(NewConfigService)
c.Provide(NewLookupService)
c.Provide(NewSteamService)
c.Provide(NewWindowsService)
c.Provide(NewFirewallService)
c.Provide(NewMembershipService)
c.Provide(NewWebSocketService)
logging.Debug("Initializing service dependencies")
err := c.Invoke(func(server *ServerService, api *ServiceControlService, config *ConfigService) {

View File

@@ -24,9 +24,9 @@ func NewServiceControlService(repository *repository.ServiceControlRepository,
repository: repository,
serverRepository: serverRepository,
statusCache: model.NewServerStatusCache(model.CacheConfig{
ExpirationTime: 30 * time.Second, // Cache expires after 30 seconds
ThrottleTime: 5 * time.Second, // Minimum 5 seconds between checks
DefaultStatus: model.StatusRunning, // Default to running if throttled
ExpirationTime: 30 * time.Second,
ThrottleTime: 5 * time.Second,
DefaultStatus: model.StatusRunning,
}),
windowsService: NewWindowsService(),
}
@@ -42,18 +42,15 @@ func (as *ServiceControlService) GetStatus(ctx *fiber.Ctx) (string, error) {
return "", err
}
// Try to get status from cache
if status, shouldCheck := as.statusCache.GetStatus(serviceName); !shouldCheck {
return status.String(), nil
}
// If cache miss or expired, check actual status
statusStr, err := as.StatusServer(serviceName)
if err != nil {
return "", err
}
// Parse and update cache with new status
status := model.ParseServiceStatus(statusStr)
as.statusCache.UpdateStatus(serviceName, status)
return status.String(), nil
@@ -65,7 +62,6 @@ func (as *ServiceControlService) ServiceControlStartServer(ctx *fiber.Ctx) (stri
return "", err
}
// Update status cache for this service before starting
as.statusCache.UpdateStatus(serviceName, model.StatusStarting)
_, err = as.StartServer(serviceName)
@@ -77,7 +73,6 @@ func (as *ServiceControlService) ServiceControlStartServer(ctx *fiber.Ctx) (stri
return "", err
}
// Parse and update cache with new status
status := model.ParseServiceStatus(statusStr)
as.statusCache.UpdateStatus(serviceName, status)
return status.String(), nil
@@ -89,7 +84,6 @@ func (as *ServiceControlService) ServiceControlStopServer(ctx *fiber.Ctx) (strin
return "", err
}
// Update status cache for this service before stopping
as.statusCache.UpdateStatus(serviceName, model.StatusStopping)
_, err = as.StopServer(serviceName)
@@ -101,7 +95,6 @@ func (as *ServiceControlService) ServiceControlStopServer(ctx *fiber.Ctx) (strin
return "", err
}
// Parse and update cache with new status
status := model.ParseServiceStatus(statusStr)
as.statusCache.UpdateStatus(serviceName, status)
return status.String(), nil
@@ -113,7 +106,6 @@ func (as *ServiceControlService) ServiceControlRestartServer(ctx *fiber.Ctx) (st
return "", err
}
// Update status cache for this service before restarting
as.statusCache.UpdateStatus(serviceName, model.StatusRestarting)
_, err = as.RestartServer(serviceName)
@@ -125,7 +117,6 @@ func (as *ServiceControlService) ServiceControlRestartServer(ctx *fiber.Ctx) (st
return "", err
}
// Parse and update cache with new status
status := model.ParseServiceStatus(statusStr)
as.statusCache.UpdateStatus(serviceName, status)
return status.String(), nil
@@ -135,20 +126,16 @@ func (as *ServiceControlService) StatusServer(serviceName string) (string, error
return as.windowsService.Status(context.Background(), serviceName)
}
// GetCachedStatus gets the cached status for a service name without requiring fiber context
func (as *ServiceControlService) GetCachedStatus(serviceName string) (string, error) {
// Try to get status from cache
if status, shouldCheck := as.statusCache.GetStatus(serviceName); !shouldCheck {
return status.String(), nil
}
// If cache miss or expired, check actual status
statusStr, err := as.StatusServer(serviceName)
if err != nil {
return "", err
}
// Parse and update cache with new status
status := model.ParseServiceStatus(statusStr)
as.statusCache.UpdateStatus(serviceName, status)
return status.String(), nil

View File

@@ -6,7 +6,7 @@ import (
)
type ServiceManager struct {
executor *command.CommandExecutor
executor *command.CommandExecutor
psExecutor *command.CommandExecutor
}
@@ -24,17 +24,14 @@ func NewServiceManager() *ServiceManager {
}
func (s *ServiceManager) ManageService(serviceName, action string) (string, error) {
// Run NSSM command through PowerShell to ensure elevation
output, err := s.psExecutor.ExecuteWithOutput("-nologo", "-noprofile", ".\\nssm", action, serviceName)
if err != nil {
return "", err
}
// Clean up output by removing null bytes and trimming whitespace
cleaned := strings.TrimSpace(strings.ReplaceAll(output, "\x00", ""))
// Remove \r\n from status strings
cleaned = strings.TrimSuffix(cleaned, "\r\n")
return cleaned, nil
}
@@ -51,11 +48,9 @@ func (s *ServiceManager) Stop(serviceName string) (string, error) {
}
func (s *ServiceManager) Restart(serviceName string) (string, error) {
// First stop the service
if _, err := s.Stop(serviceName); err != nil {
return "", err
}
// Then start it again
return s.Start(serviceName)
}
}

View File

@@ -46,7 +46,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
eg, gCtx := errgroup.WithContext(ctx.UserContext())
// Get Summary Stats (Peak/Avg Players, Total Sessions)
eg.Go(func() error {
summary, err := s.repository.GetSummaryStats(gCtx, filter)
if err != nil {
@@ -61,7 +60,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
return nil
})
// Get Total Playtime
eg.Go(func() error {
playtime, err := s.repository.GetTotalPlaytime(gCtx, filter)
if err != nil {
@@ -74,7 +72,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
return nil
})
// Get Player Count Over Time
eg.Go(func() error {
playerCount, err := s.repository.GetPlayerCountOverTime(gCtx, filter)
if err != nil {
@@ -87,7 +84,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
return nil
})
// Get Session Types
eg.Go(func() error {
sessionTypes, err := s.repository.GetSessionTypes(gCtx, filter)
if err != nil {
@@ -100,7 +96,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
return nil
})
// Get Daily Activity
eg.Go(func() error {
dailyActivity, err := s.repository.GetDailyActivity(gCtx, filter)
if err != nil {
@@ -113,7 +108,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
return nil
})
// Get Recent Sessions
eg.Go(func() error {
recentSessions, err := s.repository.GetRecentSessions(gCtx, filter)
if err != nil {

View File

@@ -6,10 +6,15 @@ import (
"acc-server-manager/local/utl/command"
"acc-server-manager/local/utl/env"
"acc-server-manager/local/utl/logging"
"acc-server-manager/local/utl/security"
"context"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
)
const (
@@ -17,17 +22,25 @@ const (
)
type SteamService struct {
executor *command.CommandExecutor
repository *repository.SteamCredentialsRepository
executor *command.CommandExecutor
repository *repository.SteamCredentialsRepository
tfaManager *model.Steam2FAManager
pathValidator *security.PathValidator
downloadVerifier *security.DownloadVerifier
}
func NewSteamService(repository *repository.SteamCredentialsRepository) *SteamService {
func NewSteamService(repository *repository.SteamCredentialsRepository, tfaManager *model.Steam2FAManager) *SteamService {
baseExecutor := &command.CommandExecutor{
ExePath: "powershell",
LogOutput: true,
}
return &SteamService{
executor: &command.CommandExecutor{
ExePath: "powershell",
LogOutput: true,
},
repository: repository,
executor: baseExecutor,
repository: repository,
tfaManager: tfaManager,
pathValidator: security.NewPathValidator(),
downloadVerifier: security.NewDownloadVerifier(),
}
}
@@ -42,114 +55,272 @@ func (s *SteamService) SaveCredentials(ctx context.Context, creds *model.SteamCr
return s.repository.Save(ctx, creds)
}
func (s *SteamService) ensureSteamCMD(ctx context.Context) error {
// Get SteamCMD path from environment variable
func (s *SteamService) ensureSteamCMD(_ context.Context) error {
steamCMDPath := env.GetSteamCMDPath()
steamCMDDir := filepath.Dir(steamCMDPath)
// Check if SteamCMD exists
if _, err := os.Stat(steamCMDPath); !os.IsNotExist(err) {
return nil
}
// Create directory if it doesn't exist
if err := os.MkdirAll(steamCMDDir, 0755); err != nil {
return fmt.Errorf("failed to create SteamCMD directory: %v", err)
}
// Download and install SteamCMD
logging.Info("Downloading SteamCMD...")
if err := s.executor.Execute("-Command",
"Invoke-WebRequest -Uri 'https://steamcdn-a.akamaihd.net/client/installer/steamcmd.zip' -OutFile 'steamcmd.zip'"); err != nil {
steamCMDZip := filepath.Join(steamCMDDir, "steamcmd.zip")
if err := s.downloadVerifier.VerifyAndDownload(
"https://steamcdn-a.akamaihd.net/client/installer/steamcmd.zip",
steamCMDZip,
""); err != nil {
return fmt.Errorf("failed to download SteamCMD: %v", err)
}
// Extract SteamCMD
logging.Info("Extracting SteamCMD...")
if err := s.executor.Execute("-Command",
fmt.Sprintf("Expand-Archive -Path 'steamcmd.zip' -DestinationPath '%s'", steamCMDDir)); err != nil {
return fmt.Errorf("failed to extract SteamCMD: %v", err)
}
// Clean up zip file
os.Remove("steamcmd.zip")
return nil
}
func (s *SteamService) InstallServer(ctx context.Context, installPath string) error {
func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPath string, serverID *uuid.UUID, wsService *WebSocketService) error {
if err := s.ensureSteamCMD(ctx); err != nil {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Error ensuring SteamCMD: %v", err), true)
return err
}
// Convert to absolute path and ensure proper Windows path format
if err := s.pathValidator.ValidateInstallPath(installPath); err != nil {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Invalid installation path: %v", err), true)
return fmt.Errorf("invalid installation path: %v", err)
}
absPath, err := filepath.Abs(installPath)
if err != nil {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to get absolute path: %v", err), true)
return fmt.Errorf("failed to get absolute path: %v", err)
}
absPath = filepath.Clean(absPath)
// Ensure install path exists
if err := os.MkdirAll(absPath, 0755); err != nil {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to create install directory: %v", err), true)
return fmt.Errorf("failed to create install directory: %v", err)
}
// Get Steam credentials
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Installation directory prepared: %s", absPath), false)
creds, err := s.GetCredentials(ctx)
if err != nil {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to get Steam credentials: %v", err), true)
return fmt.Errorf("failed to get Steam credentials: %v", err)
}
// Get SteamCMD path from environment variable
steamCMDPath := env.GetSteamCMDPath()
// Build SteamCMD command
args := []string{
"-nologo",
"-noprofile",
steamCMDPath,
steamCMDArgs := []string{
"+force_install_dir", absPath,
"+login",
}
if creds != nil && creds.Username != "" {
args = append(args, creds.Username)
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Using Steam credentials for user: %s", creds.Username), false)
steamCMDArgs = append(steamCMDArgs, creds.Username)
if creds.Password != "" {
args = append(args, creds.Password)
steamCMDArgs = append(steamCMDArgs, creds.Password)
}
} else {
args = append(args, "anonymous")
wsService.BroadcastSteamOutput(*serverID, "Using anonymous Steam login", false)
steamCMDArgs = append(steamCMDArgs, "anonymous")
}
args = append(args,
steamCMDArgs = append(steamCMDArgs,
"+app_update", ACCServerAppID,
"validate",
"+quit",
)
// Run SteamCMD
logging.Info("Installing ACC server to %s...", absPath)
if err := s.executor.Execute(args...); err != nil {
args := steamCMDArgs
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Starting SteamCMD: %s %s", steamCMDPath, strings.Join(args, " ")), false)
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
defer cancel()
callbackConfig := &command.CallbackConfig{
OnOutput: func(serverID uuid.UUID, output string, isError bool) {
wsService.BroadcastSteamOutput(serverID, output, isError)
},
OnCommand: func(serverID uuid.UUID, command string, args []string, completed bool, success bool, error string) {
if completed {
if success {
wsService.BroadcastSteamOutput(serverID, "Command completed successfully", false)
} else {
wsService.BroadcastSteamOutput(serverID, fmt.Sprintf("Command failed: %s", error), true)
}
}
},
}
callbackInteractiveExecutor := command.NewCallbackInteractiveCommandExecutor(s.executor, s.tfaManager, callbackConfig, *serverID)
callbackInteractiveExecutor.ExePath = steamCMDPath
if err := callbackInteractiveExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("SteamCMD execution failed: %v", err), true)
if timeoutCtx.Err() == context.DeadlineExceeded {
return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required")
}
return fmt.Errorf("failed to run SteamCMD: %v", err)
}
// Add a delay to allow Steam to properly cleanup
logging.Info("Waiting for Steam operations to complete...")
if err := s.executor.Execute("-Command", "Start-Sleep -Seconds 5"); err != nil {
logging.Warn("Failed to wait after Steam operations: %v", err)
}
wsService.BroadcastSteamOutput(*serverID, "SteamCMD execution completed successfully, proceeding with verification...", false)
wsService.BroadcastSteamOutput(*serverID, "Waiting for Steam operations to complete...", false)
time.Sleep(5 * time.Second)
// Verify installation
exePath := filepath.Join(absPath, "server", "accServer.exe")
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Checking for ACC server executable at: %s", exePath), false)
if _, err := os.Stat(exePath); os.IsNotExist(err) {
return fmt.Errorf("server installation failed: accServer.exe not found in %s", absPath)
wsService.BroadcastSteamOutput(*serverID, "accServer.exe not found, checking directory contents...", false)
if entries, dirErr := os.ReadDir(absPath); dirErr == nil {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Contents of %s:", absPath), false)
for _, entry := range entries {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false)
}
}
serverDir := filepath.Join(absPath, "server")
if entries, dirErr := os.ReadDir(serverDir); dirErr == nil {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Contents of %s:", serverDir), false)
for _, entry := range entries {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false)
}
} else {
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Server directory %s does not exist or cannot be read: %v", serverDir, dirErr), true)
}
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Server installation failed: accServer.exe not found in %s", exePath), true)
return fmt.Errorf("server installation failed: accServer.exe not found in %s", exePath)
}
logging.Info("Server installation completed successfully")
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Server installation completed successfully - accServer.exe found at %s", exePath), false)
return nil
}
func (s *SteamService) UpdateServer(ctx context.Context, installPath string) error {
return s.InstallServer(ctx, installPath) // Same process as install
func (s *SteamService) InstallServerWithCallbacks(ctx context.Context, installPath string, serverID *uuid.UUID, outputCallback command.OutputCallback) error {
if err := s.ensureSteamCMD(ctx); err != nil {
outputCallback(*serverID, fmt.Sprintf("Error ensuring SteamCMD: %v", err), true)
return err
}
if err := s.pathValidator.ValidateInstallPath(installPath); err != nil {
outputCallback(*serverID, fmt.Sprintf("Invalid installation path: %v", err), true)
return fmt.Errorf("invalid installation path: %v", err)
}
absPath, err := filepath.Abs(installPath)
if err != nil {
outputCallback(*serverID, fmt.Sprintf("Failed to get absolute path: %v", err), true)
return fmt.Errorf("failed to get absolute path: %v", err)
}
absPath = filepath.Clean(absPath)
if err := os.MkdirAll(absPath, 0755); err != nil {
outputCallback(*serverID, fmt.Sprintf("Failed to create install directory: %v", err), true)
return fmt.Errorf("failed to create install directory: %v", err)
}
outputCallback(*serverID, fmt.Sprintf("Installation directory prepared: %s", absPath), false)
creds, err := s.GetCredentials(ctx)
if err != nil {
outputCallback(*serverID, fmt.Sprintf("Failed to get Steam credentials: %v", err), true)
return fmt.Errorf("failed to get Steam credentials: %v", err)
}
steamCMDPath := env.GetSteamCMDPath()
steamCMDArgs := []string{
"+force_install_dir", absPath,
"+login",
}
if creds != nil && creds.Username != "" {
outputCallback(*serverID, fmt.Sprintf("Using Steam credentials for user: %s", creds.Username), false)
steamCMDArgs = append(steamCMDArgs, creds.Username)
if creds.Password != "" {
steamCMDArgs = append(steamCMDArgs, creds.Password)
}
} else {
outputCallback(*serverID, "Using anonymous Steam login", false)
steamCMDArgs = append(steamCMDArgs, "anonymous")
}
steamCMDArgs = append(steamCMDArgs,
"+app_update", ACCServerAppID,
"validate",
"+quit",
)
args := steamCMDArgs
outputCallback(*serverID, fmt.Sprintf("Starting SteamCMD: %s %s", steamCMDPath, strings.Join(args, " ")), false)
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
defer cancel()
callbacks := &command.CallbackConfig{
OnOutput: outputCallback,
}
callbackExecutor := command.NewCallbackInteractiveCommandExecutor(s.executor, s.tfaManager, callbacks, *serverID)
callbackExecutor.ExePath = steamCMDPath
if err := callbackExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil {
outputCallback(*serverID, fmt.Sprintf("SteamCMD execution failed: %v", err), true)
if timeoutCtx.Err() == context.DeadlineExceeded {
return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required")
}
return fmt.Errorf("failed to run SteamCMD: %v", err)
}
outputCallback(*serverID, "SteamCMD execution completed successfully, proceeding with verification...", false)
outputCallback(*serverID, "Waiting for Steam operations to complete...", false)
time.Sleep(5 * time.Second)
exePath := filepath.Join(absPath, "server", "accServer.exe")
outputCallback(*serverID, fmt.Sprintf("Checking for ACC server executable at: %s", exePath), false)
if _, err := os.Stat(exePath); os.IsNotExist(err) {
outputCallback(*serverID, "accServer.exe not found, checking directory contents...", false)
if entries, dirErr := os.ReadDir(absPath); dirErr == nil {
outputCallback(*serverID, fmt.Sprintf("Contents of %s:", absPath), false)
for _, entry := range entries {
outputCallback(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false)
}
}
serverDir := filepath.Join(absPath, "server")
if entries, dirErr := os.ReadDir(serverDir); dirErr == nil {
outputCallback(*serverID, fmt.Sprintf("Contents of %s:", serverDir), false)
for _, entry := range entries {
outputCallback(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false)
}
} else {
outputCallback(*serverID, fmt.Sprintf("Server directory %s does not exist or cannot be read: %v", serverDir, dirErr), true)
}
outputCallback(*serverID, fmt.Sprintf("Server installation failed: accServer.exe not found in %s", exePath), true)
return fmt.Errorf("server installation failed: accServer.exe not found in %s", exePath)
}
outputCallback(*serverID, fmt.Sprintf("Server installation completed successfully - accServer.exe found at %s", exePath), false)
return nil
}
func (s *SteamService) UninstallServer(installPath string) error {

185
local/service/websocket.go Normal file
View File

@@ -0,0 +1,185 @@
package service
import (
"acc-server-manager/local/model"
"acc-server-manager/local/utl/logging"
"encoding/json"
"sync"
"time"
"github.com/gofiber/websocket/v2"
"github.com/google/uuid"
)
type WebSocketConnection struct {
conn *websocket.Conn
serverID *uuid.UUID
userID *uuid.UUID
}
type WebSocketService struct {
connections sync.Map
mu sync.RWMutex
}
func NewWebSocketService() *WebSocketService {
return &WebSocketService{}
}
func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, userID *uuid.UUID) {
wsConn := &WebSocketConnection{
conn: conn,
userID: userID,
}
ws.connections.Store(connID, wsConn)
logging.Info("WebSocket connection added: %s for user: %v", connID, userID)
}
func (ws *WebSocketService) RemoveConnection(connID string) {
if conn, exists := ws.connections.LoadAndDelete(connID); exists {
if wsConn, ok := conn.(*WebSocketConnection); ok {
wsConn.conn.Close()
}
}
logging.Info("WebSocket connection removed: %s", connID)
}
func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) {
if conn, exists := ws.connections.Load(connID); exists {
if wsConn, ok := conn.(*WebSocketConnection); ok {
wsConn.serverID = &serverID
}
}
}
func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerCreationStep, status model.StepStatus, message string, errorMsg string) {
stepMsg := model.StepMessage{
Step: step,
Status: status,
Message: message,
Error: errorMsg,
}
wsMsg := model.WebSocketMessage{
Type: model.MessageTypeStep,
ServerID: &serverID,
Timestamp: time.Now().Unix(),
Data: stepMsg,
}
ws.broadcastToServer(serverID, wsMsg)
}
func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output string, isError bool) {
steamMsg := model.SteamOutputMessage{
Output: output,
IsError: isError,
}
wsMsg := model.WebSocketMessage{
Type: model.MessageTypeSteamOutput,
ServerID: &serverID,
Timestamp: time.Now().Unix(),
Data: steamMsg,
}
ws.broadcastToServer(serverID, wsMsg)
}
func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, details string) {
errorMsg := model.ErrorMessage{
Error: error,
Details: details,
}
wsMsg := model.WebSocketMessage{
Type: model.MessageTypeError,
ServerID: &serverID,
Timestamp: time.Now().Unix(),
Data: errorMsg,
}
ws.broadcastToServer(serverID, wsMsg)
}
func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool, message string) {
completeMsg := model.CompleteMessage{
ServerID: serverID,
Success: success,
Message: message,
}
wsMsg := model.WebSocketMessage{
Type: model.MessageTypeComplete,
ServerID: &serverID,
Timestamp: time.Now().Unix(),
Data: completeMsg,
}
ws.broadcastToServer(serverID, wsMsg)
}
func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model.WebSocketMessage) {
data, err := json.Marshal(message)
if err != nil {
logging.Error("Failed to marshal WebSocket message: %v", err)
return
}
sentToAssociatedConnections := false
ws.connections.Range(func(key, value interface{}) bool {
if wsConn, ok := value.(*WebSocketConnection); ok {
if wsConn.serverID != nil && *wsConn.serverID == serverID {
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
ws.RemoveConnection(key.(string))
} else {
sentToAssociatedConnections = true
}
}
}
return true
})
if !sentToAssociatedConnections && (message.Type == model.MessageTypeStep || message.Type == model.MessageTypeError || message.Type == model.MessageTypeComplete) {
ws.connections.Range(func(key, value interface{}) bool {
if wsConn, ok := value.(*WebSocketConnection); ok {
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
ws.RemoveConnection(key.(string))
}
}
return true
})
}
}
func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebSocketMessage) {
data, err := json.Marshal(message)
if err != nil {
logging.Error("Failed to marshal WebSocket message: %v", err)
return
}
ws.connections.Range(func(key, value interface{}) bool {
if wsConn, ok := value.(*WebSocketConnection); ok {
if wsConn.userID != nil && *wsConn.userID == userID {
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
ws.RemoveConnection(key.(string))
}
}
}
return true
})
}
func (ws *WebSocketService) GetActiveConnections() int {
count := 0
ws.connections.Range(func(key, value interface{}) bool {
count++
return true
})
return count
}

View File

@@ -23,34 +23,25 @@ func NewWindowsService() *WindowsService {
}
}
// executeNSSM runs an NSSM command through PowerShell with elevation
func (s *WindowsService) ExecuteNSSM(ctx context.Context, args ...string) (string, error) {
// Get NSSM path from environment variable
nssmPath := env.GetNSSMPath()
// Prepend NSSM path to arguments
nssmArgs := append([]string{"-NoProfile", "-NonInteractive", "-Command", "& " + nssmPath}, args...)
output, err := s.executor.ExecuteWithOutput(nssmArgs...)
if err != nil {
// Log the full command and error for debugging
logging.Error("NSSM command failed: powershell %s", strings.Join(nssmArgs, " "))
logging.Error("NSSM error output: %s", output)
return "", err
}
// Clean up output by removing null bytes and trimming whitespace
cleaned := strings.TrimSpace(strings.ReplaceAll(output, "\x00", ""))
// Remove \r\n from status strings
cleaned = strings.TrimSuffix(cleaned, "\r\n")
return cleaned, nil
}
// Service Installation/Configuration Methods
func (s *WindowsService) CreateService(ctx context.Context, serviceName, execPath, workingDir string, args []string) error {
// Ensure paths are absolute and properly formatted for Windows
absExecPath, err := filepath.Abs(execPath)
if err != nil {
return fmt.Errorf("failed to get absolute path for executable: %v", err)
@@ -63,30 +54,24 @@ func (s *WindowsService) CreateService(ctx context.Context, serviceName, execPat
}
absWorkingDir = filepath.Clean(absWorkingDir)
// Log the paths being used
logging.Info("Creating service '%s' with:", serviceName)
logging.Info(" Executable: %s", absExecPath)
logging.Info(" Working Directory: %s", absWorkingDir)
// First remove any existing service with the same name
s.ExecuteNSSM(ctx, "remove", serviceName, "confirm")
// Install service
if _, err := s.ExecuteNSSM(ctx, "install", serviceName, absExecPath); err != nil {
return fmt.Errorf("failed to install service: %v", err)
}
// Set arguments if provided
if len(args) > 0 {
cmdArgs := append([]string{"set", serviceName, "AppParameters"}, args...)
if _, err := s.ExecuteNSSM(ctx, cmdArgs...); err != nil {
// Try to clean up on failure
s.ExecuteNSSM(ctx, "remove", serviceName, "confirm")
return fmt.Errorf("failed to set arguments: %v", err)
}
}
// Verify service was created
if _, err := s.ExecuteNSSM(ctx, "get", serviceName, "Application"); err != nil {
return fmt.Errorf("service creation verification failed: %v", err)
}
@@ -105,17 +90,13 @@ func (s *WindowsService) DeleteService(ctx context.Context, serviceName string)
}
func (s *WindowsService) UpdateService(ctx context.Context, serviceName, execPath, workingDir string, args []string) error {
// First remove the existing service
if err := s.DeleteService(ctx, serviceName); err != nil {
return err
}
// Then create it again with new parameters
return s.CreateService(ctx, serviceName, execPath, workingDir, args)
}
// Service Control Methods
func (s *WindowsService) Status(ctx context.Context, serviceName string) (string, error) {
return s.ExecuteNSSM(ctx, "status", serviceName)
}
@@ -129,11 +110,9 @@ func (s *WindowsService) Stop(ctx context.Context, serviceName string) (string,
}
func (s *WindowsService) Restart(ctx context.Context, serviceName string) (string, error) {
// First stop the service
if _, err := s.Stop(ctx, serviceName); err != nil {
return "", err
}
// Then start it again
return s.Start(ctx, serviceName)
}

64
local/utl/audit/audit.go Normal file
View File

@@ -0,0 +1,64 @@
package audit
import (
"acc-server-manager/local/utl/logging"
"context"
"time"
)
type AuditAction string
const (
ActionLogin AuditAction = "LOGIN"
ActionLogout AuditAction = "LOGOUT"
ActionServerCreate AuditAction = "SERVER_CREATE"
ActionServerUpdate AuditAction = "SERVER_UPDATE"
ActionServerDelete AuditAction = "SERVER_DELETE"
ActionServerStart AuditAction = "SERVER_START"
ActionServerStop AuditAction = "SERVER_STOP"
ActionUserCreate AuditAction = "USER_CREATE"
ActionUserUpdate AuditAction = "USER_UPDATE"
ActionUserDelete AuditAction = "USER_DELETE"
ActionConfigUpdate AuditAction = "CONFIG_UPDATE"
ActionSteamAuth AuditAction = "STEAM_AUTH"
ActionPermissionGrant AuditAction = "PERMISSION_GRANT"
ActionPermissionRevoke AuditAction = "PERMISSION_REVOKE"
)
type AuditEntry struct {
Timestamp time.Time `json:"timestamp"`
UserID string `json:"user_id"`
Username string `json:"username"`
Action AuditAction `json:"action"`
Resource string `json:"resource"`
Details string `json:"details"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
Success bool `json:"success"`
}
func LogAction(ctx context.Context, userID, username string, action AuditAction, resource, details, ipAddress, userAgent string, success bool) {
logging.InfoWithContext("AUDIT", "User %s (%s) performed %s on %s from %s - Success: %t - Details: %s",
username, userID, action, resource, ipAddress, success, details)
}
func LogAuthAction(ctx context.Context, username, ipAddress, userAgent string, success bool, details string) {
action := ActionLogin
if !success {
details = "Failed: " + details
}
LogAction(ctx, "", username, action, "authentication", details, ipAddress, userAgent, success)
}
func LogServerAction(ctx context.Context, userID, username string, action AuditAction, serverID, ipAddress, userAgent string, success bool, details string) {
LogAction(ctx, userID, username, action, "server:"+serverID, details, ipAddress, userAgent, success)
}
func LogUserManagementAction(ctx context.Context, adminUserID, adminUsername string, action AuditAction, targetUserID, ipAddress, userAgent string, success bool, details string) {
LogAction(ctx, adminUserID, adminUsername, action, "user:"+targetUserID, details, ipAddress, userAgent, success)
}
func LogConfigAction(ctx context.Context, userID, username string, configType, ipAddress, userAgent string, success bool, details string) {
LogAction(ctx, userID, username, ActionConfigUpdate, "config:"+configType, details, ipAddress, userAgent, success)
}

View File

@@ -9,26 +9,22 @@ import (
"go.uber.org/dig"
)
// CacheItem represents an item in the cache
type CacheItem struct {
Value interface{}
Expiration int64
}
// InMemoryCache is a thread-safe in-memory cache
type InMemoryCache struct {
items map[string]CacheItem
mu sync.RWMutex
}
// NewInMemoryCache creates and returns a new InMemoryCache instance
func NewInMemoryCache() *InMemoryCache {
return &InMemoryCache{
items: make(map[string]CacheItem),
}
}
// Set adds an item to the cache with an expiration duration (in seconds)
func (c *InMemoryCache) Set(key string, value interface{}, duration time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
@@ -44,7 +40,6 @@ func (c *InMemoryCache) Set(key string, value interface{}, duration time.Duratio
}
}
// Get retrieves an item from the cache
func (c *InMemoryCache) Get(key string) (interface{}, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
@@ -55,24 +50,18 @@ func (c *InMemoryCache) Get(key string) (interface{}, bool) {
}
if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration {
// Item has expired, but don't delete here to avoid lock upgrade.
// It will be overwritten on the next Set.
return nil, false
}
return item.Value, true
}
// Delete removes an item from the cache
func (c *InMemoryCache) Delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.items, key)
}
// GetOrSet retrieves an item from the cache. If the item is not found, it
// calls the provided function to get the value, sets it in the cache, and
// returns it.
func GetOrSet[T any](c *InMemoryCache, key string, duration time.Duration, fetcher func() (T, error)) (T, error) {
if cached, found := c.Get(key); found {
if value, ok := cached.(T); ok {
@@ -90,7 +79,6 @@ func GetOrSet[T any](c *InMemoryCache, key string, duration time.Duration, fetch
return value, nil
}
// Start initializes the cache and provides it to the DI container.
func Start(di *dig.Container) {
cache := NewInMemoryCache()
err := di.Provide(func() *InMemoryCache {

View File

@@ -0,0 +1,245 @@
package command
import (
"acc-server-manager/local/model"
"acc-server-manager/local/utl/logging"
"bufio"
"context"
"fmt"
"io"
"os/exec"
"strings"
"time"
"github.com/google/uuid"
)
type CallbackInteractiveCommandExecutor struct {
*InteractiveCommandExecutor
callbacks *CallbackConfig
serverID uuid.UUID
}
func NewCallbackInteractiveCommandExecutor(baseExecutor *CommandExecutor, tfaManager *model.Steam2FAManager, callbacks *CallbackConfig, serverID uuid.UUID) *CallbackInteractiveCommandExecutor {
if callbacks == nil {
callbacks = DefaultCallbackConfig()
}
return &CallbackInteractiveCommandExecutor{
InteractiveCommandExecutor: &InteractiveCommandExecutor{
CommandExecutor: baseExecutor,
tfaManager: tfaManager,
},
callbacks: callbacks,
serverID: serverID,
}
}
func (e *CallbackInteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error {
cmd := exec.CommandContext(ctx, e.ExePath, args...)
if e.WorkDir != "" {
cmd.Dir = e.WorkDir
}
stdin, err := cmd.StdinPipe()
if err != nil {
return fmt.Errorf("failed to create stdin pipe: %v", err)
}
defer stdin.Close()
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("failed to create stdout pipe: %v", err)
}
defer stdout.Close()
stderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("failed to create stderr pipe: %v", err)
}
defer stderr.Close()
logging.Info("Executing interactive command with callbacks: %s %s", e.ExePath, strings.Join(args, " "))
e.callbacks.OnCommand(e.serverID, e.ExePath, args, false, false, "")
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Starting command: %s %s", e.ExePath, strings.Join(args, " ")), false)
if err := cmd.Start(); err != nil {
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to start command: %v", err), true)
return fmt.Errorf("failed to start command: %v", err)
}
outputDone := make(chan error, 1)
cmdDone := make(chan error, 1)
go e.monitorOutputWithCallbacks(ctx, stdout, stderr, serverID, outputDone)
go func() {
cmdDone <- cmd.Wait()
}()
var cmdErr, outputErr error
completedCount := 0
for completedCount < 2 {
select {
case cmdErr = <-cmdDone:
completedCount++
logging.Info("Command execution completed")
e.callbacks.OnOutput(e.serverID, "Command execution completed", false)
case outputErr = <-outputDone:
completedCount++
logging.Info("Output monitoring completed")
case <-ctx.Done():
e.callbacks.OnOutput(e.serverID, "Command execution cancelled", true)
return ctx.Err()
}
}
if outputErr != nil {
logging.Warn("Output monitoring error: %v", outputErr)
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Output monitoring error: %v", outputErr), true)
}
success := cmdErr == nil
errorMsg := ""
if cmdErr != nil {
errorMsg = cmdErr.Error()
}
e.callbacks.OnCommand(e.serverID, e.ExePath, args, true, success, errorMsg)
return cmdErr
}
func (e *CallbackInteractiveCommandExecutor) monitorOutputWithCallbacks(ctx context.Context, stdout, stderr io.Reader, serverID *uuid.UUID, done chan error) {
defer func() {
select {
case done <- nil:
default:
}
}()
stdoutScanner := bufio.NewScanner(stdout)
stderrScanner := bufio.NewScanner(stderr)
outputChan := make(chan outputLine, 100)
readersDone := make(chan struct{}, 2)
steamConsoleStarted := false
tfaRequestCreated := false
go func() {
defer func() { readersDone <- struct{}{} }()
for stdoutScanner.Scan() {
line := stdoutScanner.Text()
if e.LogOutput {
logging.Info("STDOUT: %s", line)
}
e.callbacks.OnOutput(e.serverID, line, false)
select {
case outputChan <- outputLine{text: line, isError: false}:
case <-ctx.Done():
return
}
}
if err := stdoutScanner.Err(); err != nil {
logging.Warn("Stdout scanner error: %v", err)
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Stdout scanner error: %v", err), true)
}
}()
go func() {
defer func() { readersDone <- struct{}{} }()
for stderrScanner.Scan() {
line := stderrScanner.Text()
if e.LogOutput {
logging.Info("STDERR: %s", line)
}
e.callbacks.OnOutput(e.serverID, line, true)
select {
case outputChan <- outputLine{text: line, isError: true}:
case <-ctx.Done():
return
}
}
if err := stderrScanner.Err(); err != nil {
logging.Warn("Stderr scanner error: %v", err)
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Stderr scanner error: %v", err), true)
}
}()
readersFinished := 0
for {
select {
case <-ctx.Done():
done <- ctx.Err()
return
case <-readersDone:
readersFinished++
if readersFinished == 2 {
close(outputChan)
for lineData := range outputChan {
if e.is2FAPrompt(lineData.text) {
if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil {
logging.Error("Failed to handle 2FA prompt: %v", err)
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true)
done <- err
return
}
}
}
return
}
case lineData, ok := <-outputChan:
if !ok {
return
}
lowerLine := strings.ToLower(lineData.text)
if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") {
steamConsoleStarted = true
logging.Info("Steam Console Client startup detected - will monitor for 2FA hang")
e.callbacks.OnOutput(e.serverID, "Steam Console Client startup detected", false)
}
if e.is2FAPrompt(lineData.text) {
if !tfaRequestCreated {
e.callbacks.OnOutput(e.serverID, "2FA prompt detected - waiting for user confirmation", false)
if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil {
logging.Error("Failed to handle 2FA prompt: %v", err)
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true)
done <- err
return
}
tfaRequestCreated = true
}
}
if tfaRequestCreated && e.isSteamContinuing(lineData.text) {
logging.Info("Steam CMD appears to have continued after 2FA confirmation")
e.callbacks.OnOutput(e.serverID, "Steam CMD continued after 2FA confirmation", false)
e.autoCompletePendingRequests(serverID)
}
case <-time.After(15 * time.Second):
if steamConsoleStarted && !tfaRequestCreated {
logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA")
e.callbacks.OnOutput(e.serverID, "Waiting for Steam Guard 2FA confirmation...", false)
if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil {
logging.Error("Failed to handle Steam Guard 2FA prompt: %v", err)
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle Steam Guard 2FA prompt: %v", err), true)
done <- err
return
}
tfaRequestCreated = true
}
}
}
}
type outputLine struct {
text string
isError bool
}

View File

@@ -0,0 +1,19 @@
package command
import "github.com/google/uuid"
type OutputCallback func(serverID uuid.UUID, output string, isError bool)
type CommandCallback func(serverID uuid.UUID, command string, args []string, completed bool, success bool, error string)
type CallbackConfig struct {
OnOutput OutputCallback
OnCommand CommandCallback
}
func DefaultCallbackConfig() *CallbackConfig {
return &CallbackConfig{
OnOutput: func(uuid.UUID, string, bool) {},
OnCommand: func(uuid.UUID, string, []string, bool, bool, string) {},
}
}

View File

@@ -8,17 +8,12 @@ import (
"strings"
)
// CommandExecutor provides a base structure for executing commands
type CommandExecutor struct {
// Base executable path
ExePath string
// Working directory for commands
WorkDir string
// Whether to capture and log output
ExePath string
WorkDir string
LogOutput bool
}
// CommandBuilder helps build command arguments
type CommandBuilder struct {
args []string
}
@@ -48,10 +43,9 @@ func (b *CommandBuilder) Build() []string {
return b.args
}
// Execute runs a command with the given arguments
func (e *CommandExecutor) Execute(args ...string) error {
cmd := exec.Command(e.ExePath, args...)
if e.WorkDir != "" {
cmd.Dir = e.WorkDir
}
@@ -65,15 +59,13 @@ func (e *CommandExecutor) Execute(args ...string) error {
return cmd.Run()
}
// ExecuteWithBuilder runs a command using a CommandBuilder
func (e *CommandExecutor) ExecuteWithBuilder(builder *CommandBuilder) error {
return e.Execute(builder.Build()...)
}
// ExecuteWithOutput runs a command and returns its output
func (e *CommandExecutor) ExecuteWithOutput(args ...string) (string, error) {
cmd := exec.Command(e.ExePath, args...)
if e.WorkDir != "" {
cmd.Dir = e.WorkDir
}
@@ -83,10 +75,9 @@ func (e *CommandExecutor) ExecuteWithOutput(args ...string) (string, error) {
return string(output), err
}
// ExecuteWithEnv runs a command with custom environment variables
func (e *CommandExecutor) ExecuteWithEnv(env []string, args ...string) error {
cmd := exec.Command(e.ExePath, args...)
if e.WorkDir != "" {
cmd.Dir = e.WorkDir
}
@@ -100,4 +91,4 @@ func (e *CommandExecutor) ExecuteWithEnv(env []string, args ...string) error {
logging.Info("Executing command: %s %s", e.ExePath, strings.Join(args, " "))
return cmd.Run()
}
}

View File

@@ -0,0 +1,324 @@
package command
import (
"acc-server-manager/local/model"
"acc-server-manager/local/utl/logging"
"bufio"
"context"
"fmt"
"io"
"os"
"os/exec"
"strings"
"time"
"github.com/google/uuid"
)
type InteractiveCommandExecutor struct {
*CommandExecutor
tfaManager *model.Steam2FAManager
}
func NewInteractiveCommandExecutor(baseExecutor *CommandExecutor, tfaManager *model.Steam2FAManager) *InteractiveCommandExecutor {
return &InteractiveCommandExecutor{
CommandExecutor: baseExecutor,
tfaManager: tfaManager,
}
}
func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error {
cmd := exec.CommandContext(ctx, e.ExePath, args...)
if e.WorkDir != "" {
cmd.Dir = e.WorkDir
}
stdin, err := cmd.StdinPipe()
if err != nil {
return fmt.Errorf("failed to create stdin pipe: %v", err)
}
defer stdin.Close()
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("failed to create stdout pipe: %v", err)
}
defer stdout.Close()
stderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("failed to create stderr pipe: %v", err)
}
defer stderr.Close()
logging.Info("Executing interactive command: %s %s", e.ExePath, strings.Join(args, " "))
debugMode := os.Getenv("STEAMCMD_DEBUG") == "true"
if debugMode {
logging.Info("STEAMCMD_DEBUG mode enabled - will log all output and create proactive 2FA requests")
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start command: %v", err)
}
outputDone := make(chan error, 1)
cmdDone := make(chan error, 1)
go e.monitorOutput(ctx, stdout, stderr, serverID, outputDone)
go func() {
cmdDone <- cmd.Wait()
}()
var cmdErr, outputErr error
completedCount := 0
for completedCount < 2 {
select {
case cmdErr = <-cmdDone:
completedCount++
logging.Info("Command execution completed")
case outputErr = <-outputDone:
completedCount++
logging.Info("Output monitoring completed")
case <-ctx.Done():
return ctx.Err()
}
}
if outputErr != nil {
logging.Warn("Output monitoring error: %v", outputErr)
}
return cmdErr
}
func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, stderr io.Reader, serverID *uuid.UUID, done chan error) {
defer func() {
select {
case done <- nil:
default:
}
}()
stdoutScanner := bufio.NewScanner(stdout)
stderrScanner := bufio.NewScanner(stderr)
outputChan := make(chan string, 100)
readersDone := make(chan struct{}, 2)
steamConsoleStarted := false
tfaRequestCreated := false
go func() {
defer func() { readersDone <- struct{}{} }()
for stdoutScanner.Scan() {
line := stdoutScanner.Text()
if e.LogOutput {
logging.Info("STDOUT: %s", line)
}
if strings.Contains(strings.ToLower(line), "steam") {
logging.Info("STEAM_DEBUG: %s", line)
}
select {
case outputChan <- line:
case <-ctx.Done():
return
}
}
if err := stdoutScanner.Err(); err != nil {
logging.Warn("Stdout scanner error: %v", err)
}
}()
go func() {
defer func() { readersDone <- struct{}{} }()
for stderrScanner.Scan() {
line := stderrScanner.Text()
if e.LogOutput {
logging.Info("STDERR: %s", line)
}
if strings.Contains(strings.ToLower(line), "steam") {
logging.Info("STEAM_DEBUG_ERR: %s", line)
}
select {
case outputChan <- line:
case <-ctx.Done():
return
}
}
if err := stderrScanner.Err(); err != nil {
logging.Warn("Stderr scanner error: %v", err)
}
}()
readersFinished := 0
for {
select {
case <-ctx.Done():
done <- ctx.Err()
return
case <-readersDone:
readersFinished++
if readersFinished == 2 {
close(outputChan)
for line := range outputChan {
if e.is2FAPrompt(line) {
if err := e.handle2FAPrompt(ctx, line, serverID); err != nil {
logging.Error("Failed to handle 2FA prompt: %v", err)
done <- err
return
}
}
}
return
}
case line, ok := <-outputChan:
if !ok {
return
}
lowerLine := strings.ToLower(line)
if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") {
steamConsoleStarted = true
logging.Info("Steam Console Client startup detected - will monitor for 2FA hang")
}
if e.is2FAPrompt(line) {
if !tfaRequestCreated {
if err := e.handle2FAPrompt(ctx, line, serverID); err != nil {
logging.Error("Failed to handle 2FA prompt: %v", err)
done <- err
return
}
tfaRequestCreated = true
}
}
if tfaRequestCreated && e.isSteamContinuing(line) {
logging.Info("Steam CMD appears to have continued after 2FA confirmation - auto-completing 2FA request")
e.autoCompletePendingRequests(serverID)
}
case <-time.After(15 * time.Second):
if steamConsoleStarted && !tfaRequestCreated {
logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA")
if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil {
logging.Error("Failed to handle Steam Guard 2FA prompt: %v", err)
done <- err
return
}
tfaRequestCreated = true
} else if !steamConsoleStarted {
logging.Info("No output for 15 seconds (Steam Console not yet started)")
}
}
}
}
func (e *InteractiveCommandExecutor) is2FAPrompt(line string) bool {
twoFAKeywords := []string{
"please enter your steam guard code",
"steam guard",
"two-factor",
"authentication code",
"please check your steam mobile app",
"confirm in application",
"enter the current code from your steam mobile app",
"steam guard mobile authenticator",
"waiting for user info",
"login failure",
"two factor code required",
"enter steam guard code",
"mobile authenticator code",
"authenticator app",
"guard code",
"mobile app",
"confirmation required",
}
lowerLine := strings.ToLower(line)
for _, keyword := range twoFAKeywords {
if strings.Contains(lowerLine, keyword) {
logging.Info("2FA keyword match found: '%s' in line: '%s'", keyword, line)
return true
}
}
waitingPatterns := []string{
"waiting for",
"please enter",
"enter code",
"code:",
"authenticator:",
}
for _, pattern := range waitingPatterns {
if strings.Contains(lowerLine, pattern) {
logging.Info("Potential 2FA waiting pattern found: '%s' in line: '%s'", pattern, line)
return true
}
}
return false
}
func (e *InteractiveCommandExecutor) isSteamContinuing(line string) bool {
lowerLine := strings.ToLower(line)
continuingPatterns := []string{
"loading steam api",
"logging in user",
"waiting for client config",
"waiting for user info",
"update state",
"success! app",
"fully installed",
}
for _, pattern := range continuingPatterns {
if strings.Contains(lowerLine, pattern) {
return true
}
}
return false
}
func (e *InteractiveCommandExecutor) autoCompletePendingRequests(serverID *uuid.UUID) {
if e.tfaManager == nil {
return
}
pendingRequests := e.tfaManager.GetPendingRequests()
for _, req := range pendingRequests {
if req.ServerID != nil && serverID != nil && *req.ServerID == *serverID {
logging.Info("Auto-completing 2FA request %s for server %s", req.ID, serverID.String())
if err := e.tfaManager.CompleteRequest(req.ID); err != nil {
logging.Warn("Failed to auto-complete 2FA request %s: %v", req.ID, err)
}
}
}
}
func (e *InteractiveCommandExecutor) handle2FAPrompt(_ context.Context, promptLine string, serverID *uuid.UUID) error {
logging.Info("2FA prompt detected: %s", promptLine)
request := e.tfaManager.CreateRequest(promptLine, serverID)
logging.Info("Created 2FA request with ID: %s", request.ID)
timeout := 5 * time.Minute
success, err := e.tfaManager.WaitForCompletion(request.ID, timeout)
if err != nil {
logging.Error("2FA completion failed: %v", err)
return err
}
if !success {
logging.Error("2FA was not completed successfully")
return fmt.Errorf("2FA authentication failed")
}
logging.Info("2FA completed successfully")
return nil
}

View File

@@ -25,6 +25,7 @@ type RouteGroups struct {
StateHistory fiber.Router
Membership fiber.Router
System fiber.Router
WebSocket fiber.Router
}
func CheckError(err error) {
@@ -78,12 +79,6 @@ func IndentJson(body []byte) ([]byte, error) {
return unmarshaledBody.Bytes(), nil
}
// ParseQueryFilter parses query parameters into a filter struct using reflection.
// It supports various field types and uses struct tags to determine parsing behavior.
// Supported tags:
// - `query:"field_name"` - specifies the query parameter name
// - `param:"param_name"` - specifies the path parameter name
// - `time_format:"format"` - specifies the time format for parsing dates (default: RFC3339)
func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
val := reflect.ValueOf(filter)
if val.Kind() != reflect.Ptr || val.IsNil() {
@@ -93,14 +88,12 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
elem := val.Elem()
typ := elem.Type()
// Process all fields including embedded structs
var processFields func(reflect.Value, reflect.Type) error
processFields = func(val reflect.Value, typ reflect.Type) error {
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
fieldType := typ.Field(i)
// Handle embedded structs recursively
if fieldType.Anonymous {
if err := processFields(field, fieldType.Type); err != nil {
return err
@@ -108,12 +101,10 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
continue
}
// Skip if field cannot be set
if !field.CanSet() {
continue
}
// Check for param tag first (path parameters)
if paramName := fieldType.Tag.Get("param"); paramName != "" {
if err := parsePathParam(c, field, paramName); err != nil {
return fmt.Errorf("error parsing path parameter %s: %v", paramName, err)
@@ -121,15 +112,14 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
continue
}
// Then check for query tag
queryName := fieldType.Tag.Get("query")
if queryName == "" {
queryName = ToSnakeCase(fieldType.Name) // Default to snake_case of field name
queryName = ToSnakeCase(fieldType.Name)
}
queryVal := c.Query(queryName)
if queryVal == "" {
continue // Skip empty values
continue
}
if err := parseValue(field, queryVal, fieldType.Tag); err != nil {

View File

@@ -8,7 +8,7 @@ import (
)
var (
Version = "0.10.2"
Version = "0.10.7"
Prefix = "v1"
Secret string
SecretCode string
@@ -18,7 +18,6 @@ var (
func Init() {
godotenv.Load()
// Fail fast if critical environment variables are missing
Secret = getEnvRequired("APP_SECRET")
SecretCode = getEnvRequired("APP_SECRET_CODE")
EncryptionKey = getEnvRequired("ENCRYPTION_KEY")
@@ -29,7 +28,6 @@ func Init() {
}
}
// getEnv retrieves an environment variable or returns a fallback value.
func getEnv(key, fallback string) string {
if value, exists := os.LookupEnv(key); exists {
return value
@@ -38,12 +36,10 @@ func getEnv(key, fallback string) string {
return fallback
}
// getEnvRequired retrieves an environment variable and fails if it's not set.
// This should be used for critical configuration that must not have defaults.
func getEnvRequired(key string) string {
if value, exists := os.LookupEnv(key); exists && value != "" {
return value
}
log.Fatalf("Required environment variable %s is not set or is empty", key)
return "" // This line will never be reached due to log.Fatalf
return ""
}

View File

@@ -33,7 +33,6 @@ func Start(di *dig.Container) {
func Migrate(db *gorm.DB) {
logging.Info("Migrating database")
// Run GORM AutoMigrate for all models
err := db.AutoMigrate(
&model.ServiceControlModel{},
&model.Config{},
@@ -52,7 +51,6 @@ func Migrate(db *gorm.DB) {
if err != nil {
logging.Error("GORM AutoMigrate failed: %v", err)
// Don't panic, just log the error as custom migrations may have handled this
}
db.FirstOrCreate(&model.ServiceControlModel{ServiceControl: "Works"})
@@ -63,10 +61,8 @@ func Migrate(db *gorm.DB) {
func runMigrations(db *gorm.DB) {
logging.Info("Running custom database migrations...")
// Migration 001: Password security upgrade
if err := migrations.RunPasswordSecurityMigration(db); err != nil {
logging.Error("Failed to run password security migration: %v", err)
// Continue - this migration might not be needed for all setups
}
logging.Info("Custom database migrations completed")
@@ -132,7 +128,6 @@ func seedCarModels(db *gorm.DB) error {
carModels := []model.CarModel{
{Value: 0, CarModel: "Porsche 991 GT3 R"},
{Value: 1, CarModel: "Mercedes-AMG GT3"},
// ... Add all car models from your list
}
for _, cm := range carModels {

View File

@@ -6,12 +6,10 @@ import (
)
const (
// Default paths for when environment variables are not set
DefaultSteamCMDPath = "c:\\steamcmd\\steamcmd.exe"
DefaultNSSMPath = ".\\nssm.exe"
)
// GetSteamCMDPath returns the SteamCMD executable path from environment variable or default
func GetSteamCMDPath() string {
if path := os.Getenv("STEAMCMD_PATH"); path != "" {
return path
@@ -19,13 +17,11 @@ func GetSteamCMDPath() string {
return DefaultSteamCMDPath
}
// GetSteamCMDDirPath returns the directory containing SteamCMD executable
func GetSteamCMDDirPath() string {
steamCMDPath := GetSteamCMDPath()
return filepath.Dir(steamCMDPath)
}
// GetNSSMPath returns the NSSM executable path from environment variable or default
func GetNSSMPath() string {
if path := os.Getenv("NSSM_PATH"); path != "" {
return path
@@ -33,17 +29,14 @@ func GetNSSMPath() string {
return DefaultNSSMPath
}
// ValidatePaths checks if the configured paths exist (optional validation)
func ValidatePaths() map[string]error {
errors := make(map[string]error)
// Check SteamCMD path
steamCMDPath := GetSteamCMDPath()
if _, err := os.Stat(steamCMDPath); os.IsNotExist(err) {
errors["STEAMCMD_PATH"] = err
}
// Check NSSM path
nssmPath := GetNSSMPath()
if _, err := os.Stat(nssmPath); os.IsNotExist(err) {
errors["NSSM_PATH"] = err

View File

@@ -9,45 +9,37 @@ import (
"github.com/gofiber/fiber/v2"
)
// ControllerErrorHandler provides centralized error handling for controllers
type ControllerErrorHandler struct {
errorLogger *logging.ErrorLogger
}
// NewControllerErrorHandler creates a new controller error handler instance
func NewControllerErrorHandler() *ControllerErrorHandler {
return &ControllerErrorHandler{
errorLogger: logging.GetErrorLogger(),
}
}
// ErrorResponse represents a standardized error response
type ErrorResponse struct {
Error string `json:"error"`
Code int `json:"code,omitempty"`
Details map[string]string `json:"details,omitempty"`
}
// HandleError handles controller errors with logging and standardized responses
func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCode int, context ...string) error {
if err == nil {
return nil
}
// Get caller information for logging
_, file, line, _ := runtime.Caller(1)
file = strings.TrimPrefix(file, "acc-server-manager/")
// Build context string
contextStr := ""
if len(context) > 0 {
contextStr = fmt.Sprintf("[%s] ", strings.Join(context, "|"))
}
// Clean error message (remove null bytes)
cleanErrorMsg := strings.ReplaceAll(err.Error(), "\x00", "")
// Log the error with context
ceh.errorLogger.LogWithContext(
fmt.Sprintf("CONTROLLER_ERROR [%s:%d]", file, line),
"%s%s",
@@ -55,25 +47,39 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo
cleanErrorMsg,
)
// Create standardized error response
errorResponse := ErrorResponse{
Error: cleanErrorMsg,
Code: statusCode,
}
// Add request details if available
if c != nil {
if errorResponse.Details == nil {
errorResponse.Details = make(map[string]string)
}
errorResponse.Details["method"] = c.Method()
errorResponse.Details["path"] = c.Path()
errorResponse.Details["ip"] = c.IP()
func() {
defer func() {
if r := recover(); r != nil {
return
}
}()
errorResponse.Details["method"] = c.Method()
errorResponse.Details["path"] = c.Path()
if ip := c.IP(); ip != "" {
errorResponse.Details["ip"] = ip
} else {
errorResponse.Details["ip"] = "unknown"
}
}()
}
if c == nil {
return fmt.Errorf("cannot return HTTP response: context is nil")
}
// Return appropriate response based on status code
if statusCode >= 500 {
// For server errors, don't expose internal details
return c.Status(statusCode).JSON(ErrorResponse{
Error: "Internal server error",
Code: statusCode,
@@ -83,52 +89,42 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo
return c.Status(statusCode).JSON(errorResponse)
}
// HandleValidationError handles validation errors specifically
func (ceh *ControllerErrorHandler) HandleValidationError(c *fiber.Ctx, err error, field string) error {
return ceh.HandleError(c, err, fiber.StatusBadRequest, "VALIDATION", field)
}
// HandleDatabaseError handles database-related errors
func (ceh *ControllerErrorHandler) HandleDatabaseError(c *fiber.Ctx, err error) error {
return ceh.HandleError(c, err, fiber.StatusInternalServerError, "DATABASE")
}
// HandleAuthError handles authentication/authorization errors
func (ceh *ControllerErrorHandler) HandleAuthError(c *fiber.Ctx, err error) error {
return ceh.HandleError(c, err, fiber.StatusUnauthorized, "AUTH")
}
// HandleNotFoundError handles resource not found errors
func (ceh *ControllerErrorHandler) HandleNotFoundError(c *fiber.Ctx, resource string) error {
err := fmt.Errorf("%s not found", resource)
return ceh.HandleError(c, err, fiber.StatusNotFound, "NOT_FOUND")
}
// HandleBusinessLogicError handles business logic errors
func (ceh *ControllerErrorHandler) HandleBusinessLogicError(c *fiber.Ctx, err error) error {
return ceh.HandleError(c, err, fiber.StatusBadRequest, "BUSINESS_LOGIC")
}
// HandleServiceError handles service layer errors
func (ceh *ControllerErrorHandler) HandleServiceError(c *fiber.Ctx, err error) error {
return ceh.HandleError(c, err, fiber.StatusInternalServerError, "SERVICE")
}
// HandleParsingError handles request parsing errors
func (ceh *ControllerErrorHandler) HandleParsingError(c *fiber.Ctx, err error) error {
return ceh.HandleError(c, err, fiber.StatusBadRequest, "PARSING")
}
// HandleUUIDError handles UUID parsing errors
func (ceh *ControllerErrorHandler) HandleUUIDError(c *fiber.Ctx, field string) error {
err := fmt.Errorf("invalid %s format", field)
return ceh.HandleError(c, err, fiber.StatusBadRequest, "UUID_VALIDATION", field)
}
// Global controller error handler instance
var globalErrorHandler *ControllerErrorHandler
// GetControllerErrorHandler returns the global controller error handler instance
func GetControllerErrorHandler() *ControllerErrorHandler {
if globalErrorHandler == nil {
globalErrorHandler = NewControllerErrorHandler()
@@ -136,49 +132,38 @@ func GetControllerErrorHandler() *ControllerErrorHandler {
return globalErrorHandler
}
// Convenience functions using the global error handler
// HandleError handles controller errors using the global error handler
func HandleError(c *fiber.Ctx, err error, statusCode int, context ...string) error {
return GetControllerErrorHandler().HandleError(c, err, statusCode, context...)
}
// HandleValidationError handles validation errors using the global error handler
func HandleValidationError(c *fiber.Ctx, err error, field string) error {
return GetControllerErrorHandler().HandleValidationError(c, err, field)
}
// HandleDatabaseError handles database errors using the global error handler
func HandleDatabaseError(c *fiber.Ctx, err error) error {
return GetControllerErrorHandler().HandleDatabaseError(c, err)
}
// HandleAuthError handles auth errors using the global error handler
func HandleAuthError(c *fiber.Ctx, err error) error {
return GetControllerErrorHandler().HandleAuthError(c, err)
}
// HandleNotFoundError handles not found errors using the global error handler
func HandleNotFoundError(c *fiber.Ctx, resource string) error {
return GetControllerErrorHandler().HandleNotFoundError(c, resource)
}
// HandleBusinessLogicError handles business logic errors using the global error handler
func HandleBusinessLogicError(c *fiber.Ctx, err error) error {
return GetControllerErrorHandler().HandleBusinessLogicError(c, err)
}
// HandleServiceError handles service errors using the global error handler
func HandleServiceError(c *fiber.Ctx, err error) error {
return GetControllerErrorHandler().HandleServiceError(c, err)
}
// HandleParsingError handles parsing errors using the global error handler
func HandleParsingError(c *fiber.Ctx, err error) error {
return GetControllerErrorHandler().HandleParsingError(c, err)
}
// HandleUUIDError handles UUID errors using the global error handler
func HandleUUIDError(c *fiber.Ctx, field string) error {
return GetControllerErrorHandler().HandleUUIDError(c, field)
}

View File

@@ -0,0 +1,67 @@
package errors
import (
"acc-server-manager/local/utl/logging"
"fmt"
"os"
)
type SafeError struct {
Message string
Code int
Fatal bool
}
func (e *SafeError) Error() string {
return e.Message
}
func NewSafeError(message string, code int) *SafeError {
return &SafeError{
Message: message,
Code: code,
Fatal: false,
}
}
func NewFatalError(message string, code int) *SafeError {
return &SafeError{
Message: message,
Code: code,
Fatal: true,
}
}
func HandleError(err error, context string) {
if err == nil {
return
}
if safeErr, ok := err.(*SafeError); ok {
if safeErr.Fatal {
logging.Error("Fatal error in %s: %s", context, safeErr.Message)
if os.Getenv("ENVIRONMENT") == "production" {
logging.Error("Application shutting down due to fatal error")
os.Exit(safeErr.Code)
} else {
logging.Warn("Fatal error occurred but not exiting in non-production environment")
}
} else {
logging.Error("Error in %s: %s", context, safeErr.Message)
}
} else {
logging.Error("Unexpected error in %s: %v", context, err)
}
}
func SafeFatal(message string, args ...interface{}) {
formattedMessage := fmt.Sprintf(message, args...)
err := NewFatalError(formattedMessage, 1)
HandleError(err, "application")
}
func SafeLog(message string, args ...interface{}) {
formattedMessage := fmt.Sprintf(message, args...)
err := NewSafeError(formattedMessage, 0)
HandleError(err, "application")
}

View File

@@ -0,0 +1,91 @@
package graceful
import (
"context"
"os"
"os/signal"
"sync"
"syscall"
"time"
)
type ShutdownManager struct {
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
handlers []func() error
mutex sync.Mutex
}
var globalManager *ShutdownManager
var once sync.Once
func GetManager() *ShutdownManager {
once.Do(func() {
ctx, cancel := context.WithCancel(context.Background())
globalManager = &ShutdownManager{
ctx: ctx,
cancel: cancel,
handlers: make([]func() error, 0),
}
go globalManager.watchSignals()
})
return globalManager
}
func (sm *ShutdownManager) watchSignals() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
<-sigChan
sm.Shutdown(30 * time.Second)
}
func (sm *ShutdownManager) AddHandler(handler func() error) {
sm.mutex.Lock()
defer sm.mutex.Unlock()
sm.handlers = append(sm.handlers, handler)
}
func (sm *ShutdownManager) Context() context.Context {
return sm.ctx
}
func (sm *ShutdownManager) AddGoroutine() {
sm.wg.Add(1)
}
func (sm *ShutdownManager) GoroutineDone() {
sm.wg.Done()
}
func (sm *ShutdownManager) RunGoroutine(fn func(ctx context.Context)) {
sm.wg.Add(1)
go func() {
defer sm.wg.Done()
fn(sm.ctx)
}()
}
func (sm *ShutdownManager) Shutdown(timeout time.Duration) {
sm.cancel()
done := make(chan struct{})
go func() {
sm.wg.Wait()
sm.mutex.Lock()
for _, handler := range sm.handlers {
handler()
}
sm.mutex.Unlock()
close(done)
}()
select {
case <-done:
case <-time.After(timeout):
}
}

View File

@@ -2,88 +2,100 @@ package jwt
import (
"acc-server-manager/local/model"
"acc-server-manager/local/utl/errors"
"crypto/rand"
"encoding/base64"
"errors"
"log"
"os"
goerrors "errors"
"time"
"github.com/golang-jwt/jwt/v4"
)
// SecretKey holds the JWT signing key loaded from environment
var SecretKey []byte
// Claims represents the JWT claims.
type Claims struct {
UserID string `json:"user_id"`
UserID string `json:"user_id"`
IsOpenToken bool `json:"is_open_token"`
jwt.RegisteredClaims
}
// init initializes the JWT secret key from environment variable
func Init() {
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
log.Fatal("JWT_SECRET environment variable is required and cannot be empty")
}
type JWTHandler struct {
SecretKey []byte
IsOpenToken bool
}
// Decode base64 secret if it looks like base64, otherwise use as-is
if decoded, err := base64.StdEncoding.DecodeString(jwtSecret); err == nil && len(decoded) >= 32 {
SecretKey = decoded
} else {
SecretKey = []byte(jwtSecret)
}
type OpenJWTHandler struct {
*JWTHandler
}
// Ensure minimum key length for security
if len(SecretKey) < 32 {
log.Fatal("JWT_SECRET must be at least 32 bytes long for security")
func NewOpenJWTHandler(jwtSecret string) *OpenJWTHandler {
jwtHandler := NewJWTHandler(jwtSecret)
jwtHandler.IsOpenToken = true
return &OpenJWTHandler{
JWTHandler: jwtHandler,
}
}
// GenerateSecretKey generates a cryptographically secure random key for JWT signing
// This is a utility function for generating new secrets, not used in normal operation
func GenerateSecretKey() string {
func NewJWTHandler(jwtSecret string) *JWTHandler {
if jwtSecret == "" {
errors.SafeFatal("JWT_SECRET environment variable is required and cannot be empty")
}
var secretKey []byte
if decoded, err := base64.StdEncoding.DecodeString(jwtSecret); err == nil && len(decoded) >= 32 {
secretKey = decoded
} else {
secretKey = []byte(jwtSecret)
}
if len(secretKey) < 32 {
errors.SafeFatal("JWT_SECRET must be at least 32 bytes long for security")
}
return &JWTHandler{
SecretKey: secretKey,
}
}
func (jh *JWTHandler) GenerateSecretKey() string {
key := make([]byte, 64) // 512 bits
if _, err := rand.Read(key); err != nil {
log.Fatal("Failed to generate random key: ", err)
errors.SafeFatal("Failed to generate random key: %v", err)
}
return base64.StdEncoding.EncodeToString(key)
}
// GenerateToken generates a new JWT for a given user.
func GenerateToken(user *model.User) (string, error) {
func (jh *JWTHandler) GenerateToken(userId string) (string, error) {
expirationTime := time.Now().Add(24 * time.Hour)
claims := &Claims{
UserID: user.ID.String(),
UserID: userId,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expirationTime),
},
IsOpenToken: jh.IsOpenToken,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(SecretKey)
return token.SignedString(jh.SecretKey)
}
func GenerateTokenWithExpiry(user *model.User, expiry time.Time) (string, error) {
func (jh *JWTHandler) GenerateTokenWithExpiry(user *model.User, expiry time.Time) (string, error) {
expirationTime := expiry
claims := &Claims{
UserID: user.ID.String(),
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expirationTime),
},
IsOpenToken: jh.IsOpenToken,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(SecretKey)
return token.SignedString(jh.SecretKey)
}
// ValidateToken validates a JWT and returns the claims if the token is valid.
func ValidateToken(tokenString string) (*Claims, error) {
func (jh *JWTHandler) ValidateToken(tokenString string) (*Claims, error) {
claims := &Claims{}
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
return SecretKey, nil
return jh.SecretKey, nil
})
if err != nil {
@@ -91,7 +103,7 @@ func ValidateToken(tokenString string) (*Claims, error) {
}
if !token.Valid {
return nil, errors.New("invalid token")
return nil, goerrors.New("invalid token")
}
return claims, nil

View File

@@ -15,7 +15,6 @@ var (
timeFormat = "2006-01-02 15:04:05.000"
)
// BaseLogger provides the core logging functionality
type BaseLogger struct {
file *os.File
logger *log.Logger
@@ -23,7 +22,6 @@ type BaseLogger struct {
initialized bool
}
// LogLevel represents different logging levels
type LogLevel string
const (
@@ -34,28 +32,23 @@ const (
LogLevelPanic LogLevel = "PANIC"
)
// Initialize creates a new base logger instance
func InitializeBase(tp string) (*BaseLogger, error) {
return newBaseLogger(tp)
}
func newBaseLogger(tp string) (*BaseLogger, error) {
// Ensure logs directory exists
if err := os.MkdirAll("logs", 0755); err != nil {
return nil, fmt.Errorf("failed to create logs directory: %v", err)
}
// Open log file with date in name
logPath := filepath.Join("logs", fmt.Sprintf("acc-server-%s-%s.log", time.Now().Format("2006-01-02"), tp))
file, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0666)
if err != nil {
return nil, fmt.Errorf("failed to open log file: %v", err)
}
// Create multi-writer for both file and console
multiWriter := io.MultiWriter(file, os.Stdout)
// Create base logger
logger := &BaseLogger{
file: file,
logger: log.New(multiWriter, "", 0),
@@ -65,13 +58,11 @@ func newBaseLogger(tp string) (*BaseLogger, error) {
return logger, nil
}
// GetBaseLogger creates and returns a new base logger instance
func GetBaseLogger(tp string) *BaseLogger {
baseLogger, _ := InitializeBase(tp)
return baseLogger
}
// Close closes the log file
func (bl *BaseLogger) Close() error {
bl.mu.Lock()
defer bl.mu.Unlock()
@@ -82,7 +73,6 @@ func (bl *BaseLogger) Close() error {
return nil
}
// Log writes a log entry with the specified level
func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) {
if bl == nil || !bl.initialized {
return
@@ -91,14 +81,11 @@ func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) {
bl.mu.RLock()
defer bl.mu.RUnlock()
// Get caller info (skip 2 frames: this function and the calling Log function)
_, file, line, _ := runtime.Caller(2)
file = filepath.Base(file)
// Format message
msg := fmt.Sprintf(format, v...)
// Format final log line
logLine := fmt.Sprintf("[%s] [%s] [%s:%d] %s",
time.Now().Format(timeFormat),
string(level),
@@ -110,7 +97,6 @@ func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) {
bl.logger.Println(logLine)
}
// LogWithCaller writes a log entry with custom caller depth
func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format string, v ...interface{}) {
if bl == nil || !bl.initialized {
return
@@ -119,14 +105,11 @@ func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format stri
bl.mu.RLock()
defer bl.mu.RUnlock()
// Get caller info with custom depth
_, file, line, _ := runtime.Caller(callerDepth)
file = filepath.Base(file)
// Format message
msg := fmt.Sprintf(format, v...)
// Format final log line
logLine := fmt.Sprintf("[%s] [%s] [%s:%d] %s",
time.Now().Format(timeFormat),
string(level),
@@ -138,7 +121,6 @@ func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format stri
bl.logger.Println(logLine)
}
// IsInitialized returns whether the base logger is initialized
func (bl *BaseLogger) IsInitialized() bool {
if bl == nil {
return false
@@ -148,19 +130,16 @@ func (bl *BaseLogger) IsInitialized() bool {
return bl.initialized
}
// RecoverAndLog recovers from panics and logs them
func RecoverAndLog() {
baseLogger := GetBaseLogger("panic")
if baseLogger != nil && baseLogger.IsInitialized() {
if r := recover(); r != nil {
// Get stack trace
buf := make([]byte, 4096)
n := runtime.Stack(buf, false)
stackTrace := string(buf[:n])
baseLogger.LogWithCaller(LogLevelPanic, 2, "Recovered from panic: %v\nStack Trace:\n%s", r, stackTrace)
// Re-panic to maintain original behavior if needed
panic(r)
}
}

View File

@@ -6,12 +6,10 @@ import (
"sync"
)
// DebugLogger handles debug-level logging
type DebugLogger struct {
base *BaseLogger
}
// NewDebugLogger creates a new debug logger instance
func NewDebugLogger() *DebugLogger {
base, _ := InitializeBase("debug")
return &DebugLogger{
@@ -19,14 +17,12 @@ func NewDebugLogger() *DebugLogger {
}
}
// Log writes a debug-level log entry
func (dl *DebugLogger) Log(format string, v ...interface{}) {
if dl.base != nil {
dl.base.Log(LogLevelDebug, format, v...)
}
}
// LogWithContext writes a debug-level log entry with additional context
func (dl *DebugLogger) LogWithContext(context string, format string, v ...interface{}) {
if dl.base != nil {
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
@@ -34,7 +30,6 @@ func (dl *DebugLogger) LogWithContext(context string, format string, v ...interf
}
}
// LogFunction logs function entry and exit for debugging
func (dl *DebugLogger) LogFunction(functionName string, args ...interface{}) {
if dl.base != nil {
if len(args) > 0 {
@@ -45,21 +40,18 @@ func (dl *DebugLogger) LogFunction(functionName string, args ...interface{}) {
}
}
// LogVariable logs variable values for debugging
func (dl *DebugLogger) LogVariable(varName string, value interface{}) {
if dl.base != nil {
dl.base.Log(LogLevelDebug, "VARIABLE [%s]: %+v", varName, value)
}
}
// LogState logs application state information
func (dl *DebugLogger) LogState(component string, state interface{}) {
if dl.base != nil {
dl.base.Log(LogLevelDebug, "STATE [%s]: %+v", component, state)
}
}
// LogSQL logs SQL queries for debugging
func (dl *DebugLogger) LogSQL(query string, args ...interface{}) {
if dl.base != nil {
if len(args) > 0 {
@@ -70,7 +62,6 @@ func (dl *DebugLogger) LogSQL(query string, args ...interface{}) {
}
}
// LogMemory logs memory usage information
func (dl *DebugLogger) LogMemory() {
if dl.base != nil {
var m runtime.MemStats
@@ -80,32 +71,27 @@ func (dl *DebugLogger) LogMemory() {
}
}
// LogGoroutines logs current number of goroutines
func (dl *DebugLogger) LogGoroutines() {
if dl.base != nil {
dl.base.Log(LogLevelDebug, "GOROUTINES: %d active", runtime.NumGoroutine())
}
}
// LogTiming logs timing information for performance debugging
func (dl *DebugLogger) LogTiming(operation string, duration interface{}) {
if dl.base != nil {
dl.base.Log(LogLevelDebug, "TIMING [%s]: %v", operation, duration)
}
}
// Helper function to convert bytes to kilobytes
func bToKb(b uint64) uint64 {
return b / 1024
}
// Global debug logger instance
var (
debugLogger *DebugLogger
debugOnce sync.Once
)
// GetDebugLogger returns the global debug logger instance
func GetDebugLogger() *DebugLogger {
debugOnce.Do(func() {
debugLogger = NewDebugLogger()
@@ -113,47 +99,38 @@ func GetDebugLogger() *DebugLogger {
return debugLogger
}
// Debug logs a debug-level message using the global debug logger
func Debug(format string, v ...interface{}) {
GetDebugLogger().Log(format, v...)
}
// DebugWithContext logs a debug-level message with context using the global debug logger
func DebugWithContext(context string, format string, v ...interface{}) {
GetDebugLogger().LogWithContext(context, format, v...)
}
// DebugFunction logs function entry and exit using the global debug logger
func DebugFunction(functionName string, args ...interface{}) {
GetDebugLogger().LogFunction(functionName, args...)
}
// DebugVariable logs variable values using the global debug logger
func DebugVariable(varName string, value interface{}) {
GetDebugLogger().LogVariable(varName, value)
}
// DebugState logs application state information using the global debug logger
func DebugState(component string, state interface{}) {
GetDebugLogger().LogState(component, state)
}
// DebugSQL logs SQL queries using the global debug logger
func DebugSQL(query string, args ...interface{}) {
GetDebugLogger().LogSQL(query, args...)
}
// DebugMemory logs memory usage information using the global debug logger
func DebugMemory() {
GetDebugLogger().LogMemory()
}
// DebugGoroutines logs current number of goroutines using the global debug logger
func DebugGoroutines() {
GetDebugLogger().LogGoroutines()
}
// DebugTiming logs timing information using the global debug logger
func DebugTiming(operation string, duration interface{}) {
GetDebugLogger().LogTiming(operation, duration)
}

View File

@@ -6,12 +6,10 @@ import (
"sync"
)
// ErrorLogger handles error-level logging
type ErrorLogger struct {
base *BaseLogger
}
// NewErrorLogger creates a new error logger instance
func NewErrorLogger() *ErrorLogger {
base, _ := InitializeBase("error")
return &ErrorLogger{
@@ -19,14 +17,12 @@ func NewErrorLogger() *ErrorLogger {
}
}
// Log writes an error-level log entry
func (el *ErrorLogger) Log(format string, v ...interface{}) {
if el.base != nil {
el.base.Log(LogLevelError, format, v...)
}
}
// LogWithContext writes an error-level log entry with additional context
func (el *ErrorLogger) LogWithContext(context string, format string, v ...interface{}) {
if el.base != nil {
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
@@ -34,7 +30,6 @@ func (el *ErrorLogger) LogWithContext(context string, format string, v ...interf
}
}
// LogError logs an error object with optional message
func (el *ErrorLogger) LogError(err error, message ...string) {
if el.base != nil && err != nil {
if len(message) > 0 {
@@ -45,7 +40,6 @@ func (el *ErrorLogger) LogError(err error, message ...string) {
}
}
// LogWithStackTrace logs an error with stack trace
func (el *ErrorLogger) LogWithStackTrace(format string, v ...interface{}) {
if el.base != nil {
// Get stack trace
@@ -58,7 +52,6 @@ func (el *ErrorLogger) LogWithStackTrace(format string, v ...interface{}) {
}
}
// LogFatal logs a fatal error and exits the program
func (el *ErrorLogger) LogFatal(format string, v ...interface{}) {
if el.base != nil {
el.base.Log(LogLevelError, "[FATAL] "+format, v...)
@@ -66,13 +59,11 @@ func (el *ErrorLogger) LogFatal(format string, v ...interface{}) {
}
}
// Global error logger instance
var (
errorLogger *ErrorLogger
errorOnce sync.Once
)
// GetErrorLogger returns the global error logger instance
func GetErrorLogger() *ErrorLogger {
errorOnce.Do(func() {
errorLogger = NewErrorLogger()
@@ -80,27 +71,22 @@ func GetErrorLogger() *ErrorLogger {
return errorLogger
}
// Error logs an error-level message using the global error logger
func Error(format string, v ...interface{}) {
GetErrorLogger().Log(format, v...)
}
// ErrorWithContext logs an error-level message with context using the global error logger
func ErrorWithContext(context string, format string, v ...interface{}) {
GetErrorLogger().LogWithContext(context, format, v...)
}
// LogError logs an error object using the global error logger
func LogError(err error, message ...string) {
GetErrorLogger().LogError(err, message...)
}
// ErrorWithStackTrace logs an error with stack trace using the global error logger
func ErrorWithStackTrace(format string, v ...interface{}) {
GetErrorLogger().LogWithStackTrace(format, v...)
}
// Fatal logs a fatal error and exits the program using the global error logger
func Fatal(format string, v ...interface{}) {
GetErrorLogger().LogFatal(format, v...)
}

View File

@@ -5,12 +5,10 @@ import (
"sync"
)
// InfoLogger handles info-level logging
type InfoLogger struct {
base *BaseLogger
}
// NewInfoLogger creates a new info logger instance
func NewInfoLogger() *InfoLogger {
base, _ := InitializeBase("info")
return &InfoLogger{
@@ -18,14 +16,12 @@ func NewInfoLogger() *InfoLogger {
}
}
// Log writes an info-level log entry
func (il *InfoLogger) Log(format string, v ...interface{}) {
if il.base != nil {
il.base.Log(LogLevelInfo, format, v...)
}
}
// LogWithContext writes an info-level log entry with additional context
func (il *InfoLogger) LogWithContext(context string, format string, v ...interface{}) {
if il.base != nil {
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
@@ -33,55 +29,47 @@ func (il *InfoLogger) LogWithContext(context string, format string, v ...interfa
}
}
// LogStartup logs application startup information
func (il *InfoLogger) LogStartup(component string, message string) {
if il.base != nil {
il.base.Log(LogLevelInfo, "STARTUP [%s]: %s", component, message)
}
}
// LogShutdown logs application shutdown information
func (il *InfoLogger) LogShutdown(component string, message string) {
if il.base != nil {
il.base.Log(LogLevelInfo, "SHUTDOWN [%s]: %s", component, message)
}
}
// LogOperation logs general operation information
func (il *InfoLogger) LogOperation(operation string, details string) {
if il.base != nil {
il.base.Log(LogLevelInfo, "OPERATION [%s]: %s", operation, details)
}
}
// LogStatus logs status changes or updates
func (il *InfoLogger) LogStatus(component string, status string) {
if il.base != nil {
il.base.Log(LogLevelInfo, "STATUS [%s]: %s", component, status)
}
}
// LogRequest logs incoming requests
func (il *InfoLogger) LogRequest(method string, path string, userAgent string) {
if il.base != nil {
il.base.Log(LogLevelInfo, "REQUEST [%s %s] User-Agent: %s", method, path, userAgent)
}
}
// LogResponse logs outgoing responses
func (il *InfoLogger) LogResponse(method string, path string, statusCode int, duration string) {
if il.base != nil {
il.base.Log(LogLevelInfo, "RESPONSE [%s %s] Status: %d, Duration: %s", method, path, statusCode, duration)
}
}
// Global info logger instance
var (
infoLogger *InfoLogger
infoOnce sync.Once
)
// GetInfoLogger returns the global info logger instance
func GetInfoLogger() *InfoLogger {
infoOnce.Do(func() {
infoLogger = NewInfoLogger()
@@ -89,42 +77,34 @@ func GetInfoLogger() *InfoLogger {
return infoLogger
}
// Info logs an info-level message using the global info logger
func Info(format string, v ...interface{}) {
GetInfoLogger().Log(format, v...)
}
// InfoWithContext logs an info-level message with context using the global info logger
func InfoWithContext(context string, format string, v ...interface{}) {
GetInfoLogger().LogWithContext(context, format, v...)
}
// InfoStartup logs application startup information using the global info logger
func InfoStartup(component string, message string) {
GetInfoLogger().LogStartup(component, message)
}
// InfoShutdown logs application shutdown information using the global info logger
func InfoShutdown(component string, message string) {
GetInfoLogger().LogShutdown(component, message)
}
// InfoOperation logs general operation information using the global info logger
func InfoOperation(operation string, details string) {
GetInfoLogger().LogOperation(operation, details)
}
// InfoStatus logs status changes or updates using the global info logger
func InfoStatus(component string, status string) {
GetInfoLogger().LogStatus(component, status)
}
// InfoRequest logs incoming requests using the global info logger
func InfoRequest(method string, path string, userAgent string) {
GetInfoLogger().LogRequest(method, path, userAgent)
}
// InfoResponse logs outgoing responses using the global info logger
func InfoResponse(method string, path string, statusCode int, duration string) {
GetInfoLogger().LogResponse(method, path, statusCode, duration)
}

View File

@@ -6,12 +6,10 @@ import (
)
var (
// Legacy logger for backward compatibility
logger *Logger
once sync.Once
)
// Logger maintains backward compatibility with existing code
type Logger struct {
base *BaseLogger
errorLogger *ErrorLogger
@@ -20,8 +18,6 @@ type Logger struct {
debugLogger *DebugLogger
}
// Initialize creates or gets the singleton logger instance
// This maintains backward compatibility with existing code
func Initialize() (*Logger, error) {
var err error
once.Do(func() {
@@ -31,13 +27,11 @@ func Initialize() (*Logger, error) {
}
func newLogger() (*Logger, error) {
// Initialize the base logger
baseLogger, err := InitializeBase("log")
if err != nil {
return nil, err
}
// Create the legacy logger wrapper
logger := &Logger{
base: baseLogger,
errorLogger: GetErrorLogger(),
@@ -49,7 +43,6 @@ func newLogger() (*Logger, error) {
return logger, nil
}
// Close closes the logger
func (l *Logger) Close() error {
if l.base != nil {
return l.base.Close()
@@ -57,7 +50,6 @@ func (l *Logger) Close() error {
return nil
}
// Legacy methods for backward compatibility
func (l *Logger) log(level, format string, v ...interface{}) {
if l.base != nil {
l.base.LogWithCaller(LogLevel(level), 3, format, v...)
@@ -94,13 +86,10 @@ func (l *Logger) Panic(format string) {
}
}
// Global convenience functions for backward compatibility
// These are now implemented in individual logger files to avoid redeclaration
func LegacyInfo(format string, v ...interface{}) {
if logger != nil {
logger.Info(format, v...)
} else {
// Fallback to direct logger if legacy logger not initialized
GetInfoLogger().Log(format, v...)
}
}
@@ -109,7 +98,6 @@ func LegacyError(format string, v ...interface{}) {
if logger != nil {
logger.Error(format, v...)
} else {
// Fallback to direct logger if legacy logger not initialized
GetErrorLogger().Log(format, v...)
}
}
@@ -118,7 +106,6 @@ func LegacyWarn(format string, v ...interface{}) {
if logger != nil {
logger.Warn(format, v...)
} else {
// Fallback to direct logger if legacy logger not initialized
GetWarnLogger().Log(format, v...)
}
}
@@ -127,7 +114,6 @@ func LegacyDebug(format string, v ...interface{}) {
if logger != nil {
logger.Debug(format, v...)
} else {
// Fallback to direct logger if legacy logger not initialized
GetDebugLogger().Log(format, v...)
}
}
@@ -136,55 +122,42 @@ func Panic(format string) {
if logger != nil {
logger.Panic(format)
} else {
// Fallback to direct logger if legacy logger not initialized
GetErrorLogger().LogFatal(format)
}
}
// Enhanced logging convenience functions
// These provide direct access to specialized logging functions
// LogStartup logs application startup information
func LogStartup(component string, message string) {
GetInfoLogger().LogStartup(component, message)
}
// LogShutdown logs application shutdown information
func LogShutdown(component string, message string) {
GetInfoLogger().LogShutdown(component, message)
}
// LogOperation logs general operation information
func LogOperation(operation string, details string) {
GetInfoLogger().LogOperation(operation, details)
}
// LogRequest logs incoming HTTP requests
func LogRequest(method string, path string, userAgent string) {
GetInfoLogger().LogRequest(method, path, userAgent)
}
// LogResponse logs outgoing HTTP responses
func LogResponse(method string, path string, statusCode int, duration string) {
GetInfoLogger().LogResponse(method, path, statusCode, duration)
}
// LogSQL logs SQL queries for debugging
func LogSQL(query string, args ...interface{}) {
GetDebugLogger().LogSQL(query, args...)
}
// LogMemory logs memory usage information
func LogMemory() {
GetDebugLogger().LogMemory()
}
// LogTiming logs timing information for performance debugging
func LogTiming(operation string, duration interface{}) {
GetDebugLogger().LogTiming(operation, duration)
}
// GetLegacyLogger returns the legacy logger instance for backward compatibility
func GetLegacyLogger() *Logger {
if logger == nil {
logger, _ = Initialize()
@@ -192,21 +165,17 @@ func GetLegacyLogger() *Logger {
return logger
}
// InitializeLogging initializes all logging components
func InitializeLogging() error {
// Initialize legacy logger for backward compatibility
_, err := Initialize()
if err != nil {
return fmt.Errorf("failed to initialize legacy logger: %v", err)
}
// Pre-initialize all logger types to ensure separate log files
GetErrorLogger()
GetWarnLogger()
GetInfoLogger()
GetDebugLogger()
// Log successful initialization
Info("Logging system initialized successfully")
return nil

View File

@@ -5,12 +5,10 @@ import (
"sync"
)
// WarnLogger handles warn-level logging
type WarnLogger struct {
base *BaseLogger
}
// NewWarnLogger creates a new warn logger instance
func NewWarnLogger() *WarnLogger {
base, _ := InitializeBase("warn")
return &WarnLogger{
@@ -18,14 +16,12 @@ func NewWarnLogger() *WarnLogger {
}
}
// Log writes a warn-level log entry
func (wl *WarnLogger) Log(format string, v ...interface{}) {
if wl.base != nil {
wl.base.Log(LogLevelWarn, format, v...)
}
}
// LogWithContext writes a warn-level log entry with additional context
func (wl *WarnLogger) LogWithContext(context string, format string, v ...interface{}) {
if wl.base != nil {
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
@@ -33,7 +29,6 @@ func (wl *WarnLogger) LogWithContext(context string, format string, v ...interfa
}
}
// LogDeprecation logs a deprecation warning
func (wl *WarnLogger) LogDeprecation(feature string, alternative string) {
if wl.base != nil {
if alternative != "" {
@@ -44,27 +39,23 @@ func (wl *WarnLogger) LogDeprecation(feature string, alternative string) {
}
}
// LogConfiguration logs configuration-related warnings
func (wl *WarnLogger) LogConfiguration(setting string, message string) {
if wl.base != nil {
wl.base.Log(LogLevelWarn, "CONFIG WARNING [%s]: %s", setting, message)
}
}
// LogPerformance logs performance-related warnings
func (wl *WarnLogger) LogPerformance(operation string, threshold string, actual string) {
if wl.base != nil {
wl.base.Log(LogLevelWarn, "PERFORMANCE WARNING [%s]: exceeded threshold %s, actual: %s", operation, threshold, actual)
}
}
// Global warn logger instance
var (
warnLogger *WarnLogger
warnOnce sync.Once
)
// GetWarnLogger returns the global warn logger instance
func GetWarnLogger() *WarnLogger {
warnOnce.Do(func() {
warnLogger = NewWarnLogger()
@@ -72,27 +63,22 @@ func GetWarnLogger() *WarnLogger {
return warnLogger
}
// Warn logs a warn-level message using the global warn logger
func Warn(format string, v ...interface{}) {
GetWarnLogger().Log(format, v...)
}
// WarnWithContext logs a warn-level message with context using the global warn logger
func WarnWithContext(context string, format string, v ...interface{}) {
GetWarnLogger().LogWithContext(context, format, v...)
}
// WarnDeprecation logs a deprecation warning using the global warn logger
func WarnDeprecation(feature string, alternative string) {
GetWarnLogger().LogDeprecation(feature, alternative)
}
// WarnConfiguration logs configuration-related warnings using the global warn logger
func WarnConfiguration(setting string, message string) {
GetWarnLogger().LogConfiguration(setting, message)
}
// WarnPerformance logs performance-related warnings using the global warn logger
func WarnPerformance(operation string, threshold string, actual string) {
GetWarnLogger().LogPerformance(operation, threshold, actual)
}

View File

@@ -6,12 +6,10 @@ import (
"time"
)
// IsPortAvailable checks if a port is available for both TCP and UDP
func IsPortAvailable(port int) bool {
return IsTCPPortAvailable(port) && IsUDPPortAvailable(port)
}
// IsTCPPortAvailable checks if a TCP port is available
func IsTCPPortAvailable(port int) bool {
addr := fmt.Sprintf(":%d", port)
listener, err := net.Listen("tcp", addr)
@@ -22,7 +20,6 @@ func IsTCPPortAvailable(port int) bool {
return true
}
// IsUDPPortAvailable checks if a UDP port is available
func IsUDPPortAvailable(port int) bool {
conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
if err != nil {
@@ -32,7 +29,6 @@ func IsUDPPortAvailable(port int) bool {
return true
}
// FindAvailablePort finds an available port starting from the given port
func FindAvailablePort(startPort int) (int, error) {
maxPort := 65535
for port := startPort; port <= maxPort; port++ {
@@ -43,14 +39,12 @@ func FindAvailablePort(startPort int) (int, error) {
return 0, fmt.Errorf("no available ports found between %d and %d", startPort, maxPort)
}
// FindAvailablePortRange finds a range of consecutive available ports
func FindAvailablePortRange(startPort, count int) ([]int, error) {
maxPort := 65535
ports := make([]int, 0, count)
currentPort := startPort
for len(ports) < count && currentPort <= maxPort {
// Check if we have enough consecutive ports available
available := true
for i := 0; i < count-len(ports); i++ {
if !IsPortAvailable(currentPort + i) {
@@ -74,7 +68,6 @@ func FindAvailablePortRange(startPort, count int) ([]int, error) {
return ports, nil
}
// WaitForPortAvailable waits for a port to become available with timeout
func WaitForPortAvailable(port int, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
@@ -84,4 +77,4 @@ func WaitForPortAvailable(port int, timeout time.Duration) error {
time.Sleep(100 * time.Millisecond)
}
return fmt.Errorf("timeout waiting for port %d to become available", port)
}
}

View File

@@ -8,13 +8,10 @@ import (
)
const (
// MinPasswordLength defines the minimum password length
MinPasswordLength = 8
// BcryptCost defines the cost factor for bcrypt hashing
BcryptCost = 12
BcryptCost = 12
)
// HashPassword hashes a plain text password using bcrypt
func HashPassword(password string) (string, error) {
if len(password) < MinPasswordLength {
return "", errors.New("password must be at least 8 characters long")
@@ -28,12 +25,10 @@ func HashPassword(password string) (string, error) {
return string(hashedBytes), nil
}
// VerifyPassword verifies a plain text password against a hashed password
func VerifyPassword(hashedPassword, password string) error {
return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
}
// ValidatePasswordStrength validates password complexity requirements
func ValidatePasswordStrength(password string) error {
if len(password) < MinPasswordLength {
return errors.New("password must be at least 8 characters long")

View File

@@ -0,0 +1,76 @@
package security
import (
"crypto/sha256"
"fmt"
"io"
"net/http"
"os"
"time"
)
type DownloadVerifier struct {
client *http.Client
}
func NewDownloadVerifier() *DownloadVerifier {
return &DownloadVerifier{
client: &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
DisableKeepAlives: true,
},
},
}
}
func (dv *DownloadVerifier) VerifyAndDownload(url, outputPath, expectedSHA256 string) error {
if url == "" {
return fmt.Errorf("URL cannot be empty")
}
if outputPath == "" {
return fmt.Errorf("output path cannot be empty")
}
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return fmt.Errorf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", "ACC-Server-Manager/1.0")
resp, err := dv.client.Do(req)
if err != nil {
return fmt.Errorf("failed to download: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed with status: %d", resp.StatusCode)
}
file, err := os.Create(outputPath)
if err != nil {
return fmt.Errorf("failed to create output file: %v", err)
}
defer file.Close()
hash := sha256.New()
writer := io.MultiWriter(file, hash)
_, err = io.Copy(writer, resp.Body)
if err != nil {
os.Remove(outputPath)
return fmt.Errorf("failed to write file: %v", err)
}
if expectedSHA256 != "" {
actualHash := fmt.Sprintf("%x", hash.Sum(nil))
if actualHash != expectedSHA256 {
os.Remove(outputPath)
return fmt.Errorf("file hash mismatch: expected %s, got %s", expectedSHA256, actualHash)
}
}
return nil
}

View File

@@ -0,0 +1,94 @@
package security
import (
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
)
type PathValidator struct {
allowedBasePaths []string
blockedPatterns []*regexp.Regexp
}
func NewPathValidator() *PathValidator {
blockedPatterns := []*regexp.Regexp{
regexp.MustCompile(`\.\.`),
regexp.MustCompile(`^(CON|PRN|AUX|NUL|COM[1-9]|LPT[1-9])$`),
regexp.MustCompile(`\x00`),
regexp.MustCompile(`^\\\\`),
regexp.MustCompile(`^[a-zA-Z]:\\Windows`),
regexp.MustCompile(`^[a-zA-Z]:\\Program Files`),
}
return &PathValidator{
allowedBasePaths: []string{
`C:\ACC-Servers`,
`D:\ACC-Servers`,
`E:\ACC-Servers`,
`C:\SteamCMD`,
`D:\SteamCMD`,
`E:\SteamCMD`,
},
blockedPatterns: blockedPatterns,
}
}
func (pv *PathValidator) ValidateInstallPath(path string) error {
if path == "" {
return fmt.Errorf("path cannot be empty")
}
cleanPath := filepath.Clean(path)
absPath, err := filepath.Abs(cleanPath)
if err != nil {
return fmt.Errorf("invalid path: %v", err)
}
for _, pattern := range pv.blockedPatterns {
if pattern.MatchString(absPath) || pattern.MatchString(strings.ToUpper(filepath.Base(absPath))) {
return fmt.Errorf("path contains forbidden patterns")
}
}
allowed := false
for _, basePath := range pv.allowedBasePaths {
if strings.HasPrefix(strings.ToLower(absPath), strings.ToLower(basePath)) {
allowed = true
break
}
}
if !allowed {
return fmt.Errorf("path must be within allowed directories: %v", pv.allowedBasePaths)
}
if len(absPath) > 260 {
return fmt.Errorf("path too long (max 260 characters)")
}
parentDir := filepath.Dir(absPath)
if parentInfo, err := os.Stat(parentDir); err == nil {
if !parentInfo.IsDir() {
return fmt.Errorf("parent path is not a directory")
}
}
return nil
}
func (pv *PathValidator) AddAllowedBasePath(path string) error {
absPath, err := filepath.Abs(path)
if err != nil {
return fmt.Errorf("invalid base path: %v", err)
}
pv.allowedBasePaths = append(pv.allowedBasePaths, absPath)
return nil
}
func (pv *PathValidator) GetAllowedBasePaths() []string {
return append([]string(nil), pv.allowedBasePaths...)
}

View File

@@ -17,30 +17,29 @@ import (
func Start(di *dig.Container) *fiber.App {
app := fiber.New(fiber.Config{
EnablePrintRoutes: true,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
BodyLimit: 10 * 1024 * 1024, // 10MB
ReadTimeout: 20 * time.Minute,
WriteTimeout: 20 * time.Minute,
IdleTimeout: 25 * time.Minute,
BodyLimit: 10 * 1024 * 1024,
})
// Initialize security middleware
securityMW := security.NewSecurityMiddleware()
// Add security middleware stack
app.Use(securityMW.SecurityHeaders())
app.Use(securityMW.LogSecurityEvents())
app.Use(securityMW.TimeoutMiddleware(30 * time.Second))
app.Use(securityMW.RequestSizeLimit(10 * 1024 * 1024)) // 10MB
app.Use(securityMW.TimeoutMiddleware(20 * time.Minute))
app.Use(securityMW.RequestContextTimeout(20 * time.Minute))
app.Use(securityMW.RequestSizeLimit(10 * 1024 * 1024))
app.Use(securityMW.ValidateUserAgent())
app.Use(securityMW.ValidateContentType("application/json", "application/x-www-form-urlencoded", "multipart/form-data"))
app.Use(securityMW.InputSanitization())
app.Use(securityMW.RateLimit(100, 1*time.Minute)) // 100 requests per minute global
app.Use(securityMW.RateLimit(100, 1*time.Minute))
app.Use(helmet.New())
allowedOrigin := os.Getenv("CORS_ALLOWED_ORIGIN")
if allowedOrigin == "" {
allowedOrigin = "http://localhost:5173"
allowedOrigin = "http://localhost:3000"
}
app.Use(cors.New(cors.Config{
@@ -61,7 +60,7 @@ func Start(di *dig.Container) *fiber.App {
port := os.Getenv("PORT")
if port == "" {
port = "3000" // Default port
port = "3000"
}
logging.Info("Starting server on port %s", port)

View File

@@ -7,11 +7,11 @@ import (
)
type LogTailer struct {
filePath string
handleLine func(string)
stopChan chan struct{}
isRunning bool
tracker *PositionTracker
filePath string
handleLine func(string)
stopChan chan struct{}
isRunning bool
tracker *PositionTracker
}
func NewLogTailer(filePath string, handleLine func(string)) *LogTailer {
@@ -30,10 +30,9 @@ func (t *LogTailer) Start() {
t.isRunning = true
go func() {
// Load last position from tracker
pos, err := t.tracker.LoadPosition()
if err != nil {
pos = &LogPosition{} // Start from beginning if error
pos = &LogPosition{}
}
lastSize := pos.LastPosition
@@ -43,7 +42,6 @@ func (t *LogTailer) Start() {
t.isRunning = false
return
default:
// Try to open and read the file
if file, err := os.Open(t.filePath); err == nil {
stat, err := file.Stat()
if err != nil {
@@ -52,12 +50,10 @@ func (t *LogTailer) Start() {
continue
}
// If file was truncated, start from beginning
if stat.Size() < lastSize {
lastSize = 0
}
// Seek to last read position
if lastSize > 0 {
file.Seek(lastSize, 0)
}
@@ -66,9 +62,8 @@ func (t *LogTailer) Start() {
for scanner.Scan() {
line := scanner.Text()
t.handleLine(line)
lastSize, _ = file.Seek(0, 1) // Get current position
// Save position periodically
lastSize, _ = file.Seek(0, 1)
t.tracker.SavePosition(&LogPosition{
LastPosition: lastSize,
LastRead: line,
@@ -78,7 +73,6 @@ func (t *LogTailer) Start() {
file.Close()
}
// Wait before next attempt
time.Sleep(time.Second)
}
}
@@ -90,4 +84,4 @@ func (t *LogTailer) Stop() {
return
}
close(t.stopChan)
}
}

View File

@@ -7,8 +7,8 @@ import (
)
type LogPosition struct {
LastPosition int64 `json:"last_position"`
LastRead string `json:"last_read"`
LastPosition int64 `json:"last_position"`
LastRead string `json:"last_read"`
}
type PositionTracker struct {
@@ -16,11 +16,10 @@ type PositionTracker struct {
}
func NewPositionTracker(logPath string) *PositionTracker {
// Create position file in same directory as log file
dir := filepath.Dir(logPath)
base := filepath.Base(logPath)
positionFile := filepath.Join(dir, "."+base+".position")
return &PositionTracker{
positionFile: positionFile,
}
@@ -30,7 +29,6 @@ func (t *PositionTracker) LoadPosition() (*LogPosition, error) {
data, err := os.ReadFile(t.positionFile)
if err != nil {
if os.IsNotExist(err) {
// Return empty position if file doesn't exist
return &LogPosition{}, nil
}
return nil, err
@@ -51,4 +49,4 @@ func (t *PositionTracker) SavePosition(pos *LogPosition) error {
}
return os.WriteFile(t.positionFile, data, 0644)
}
}

View File

@@ -11,102 +11,104 @@ import (
)
type StateChange int
const (
PlayerCount StateChange = iota
Session
PlayerCount StateChange = iota
Session
)
var StateChanges = map[StateChange]string {
PlayerCount: "player-count",
Session: "session",
var StateChanges = map[StateChange]string{
PlayerCount: "player-count",
Session: "session",
}
type AccServerInstance struct {
Model *model.Server
State *model.ServerState
OnStateChange func(*model.ServerState, ...StateChange)
Model *model.Server
State *model.ServerState
OnStateChange func(*model.ServerState, ...StateChange)
}
func NewAccServerInstance(server *model.Server, onStateChange func(*model.ServerState, ...StateChange)) *AccServerInstance {
return &AccServerInstance{
Model: server,
State: &model.ServerState{PlayerCount: 0},
OnStateChange: onStateChange,
}
return &AccServerInstance{
Model: server,
State: &model.ServerState{PlayerCount: 0},
OnStateChange: onStateChange,
}
}
type StateRegexHandler struct {
*regex_handler.RegexHandler
test string
*regex_handler.RegexHandler
test string
}
func NewRegexHandler(str string, test string) *StateRegexHandler {
return &StateRegexHandler{
RegexHandler: regex_handler.New(str),
test: test,
RegexHandler: regex_handler.New(str),
test: test,
}
}
func (rh *StateRegexHandler) Test(line string) bool{
return strings.Contains(line, rh.test)
func (rh *StateRegexHandler) Test(line string) bool {
return strings.Contains(line, rh.test)
}
func (rh *StateRegexHandler) Count(line string) int{
var count int = 0
rh.Contains(line, func (strs ...string) {
if len(strs) == 2 {
if ct, err := strconv.Atoi(strs[1]); err == nil {
count = ct
}
}
})
return count
func (rh *StateRegexHandler) Count(line string) int {
var count int = 0
rh.Contains(line, func(strs ...string) {
if len(strs) == 2 {
if ct, err := strconv.Atoi(strs[1]); err == nil {
count = ct
}
}
})
return count
}
func (rh *StateRegexHandler) Change(line string) (string, string){
var old string = ""
var new string = ""
rh.Contains(line, func (strs ...string) {
if len(strs) == 3 {
old = strs[1]
new = strs[2]
}
})
return old, new
func (rh *StateRegexHandler) Change(line string) (string, string) {
var old string = ""
var new string = ""
rh.Contains(line, func(strs ...string) {
if len(strs) == 3 {
old = strs[1]
new = strs[2]
}
})
return old, new
}
func TailLogFile(path string, callback func(string)) {
file, _ := os.Open(path)
defer file.Close()
file, _ := os.Open(path)
defer file.Close()
file.Seek(0, os.SEEK_END) // Start at end of file
reader := bufio.NewReader(file)
file.Seek(0, os.SEEK_END)
reader := bufio.NewReader(file)
for {
line, err := reader.ReadString('\n')
if err == nil {
callback(line)
} else {
time.Sleep(500 * time.Millisecond) // wait for new data
}
}
for {
line, err := reader.ReadString('\n')
if err == nil {
callback(line)
} else {
time.Sleep(500 * time.Millisecond)
}
}
}
type LogStateType int
const (
SessionChange LogStateType = iota
LeaderboardUpdate
UDPCount
ClientsOnline
RemovingDeadConnection
SessionChange LogStateType = iota
LeaderboardUpdate
UDPCount
ClientsOnline
RemovingDeadConnection
)
var logStateContain = map[LogStateType]string {
SessionChange: "Session changed",
LeaderboardUpdate: "Updated leaderboard for",
UDPCount: "Udp message count",
ClientsOnline: "client(s) online",
RemovingDeadConnection: "Removing dead connection",
var logStateContain = map[LogStateType]string{
SessionChange: "Session changed",
LeaderboardUpdate: "Updated leaderboard for",
UDPCount: "Udp message count",
ClientsOnline: "client(s) online",
RemovingDeadConnection: "Removing dead connection",
}
var sessionChangeRegex = NewRegexHandler(`Session changed: (\w+) -> (\w+)`, logStateContain[SessionChange])
@@ -115,75 +117,77 @@ var udpCountRegex = NewRegexHandler(`Udp message count (\d+) client`, logStateCo
var clientsOnlineRegex = NewRegexHandler(`(\d+) client\(s\) online`, logStateContain[ClientsOnline])
var removingDeadConnectionsRegex = NewRegexHandler(`Removing dead connection`, logStateContain[RemovingDeadConnection])
var logStateRegex = map[LogStateType]*StateRegexHandler {
SessionChange: sessionChangeRegex,
LeaderboardUpdate: leaderboardUpdateRegex,
UDPCount: udpCountRegex,
ClientsOnline: clientsOnlineRegex,
RemovingDeadConnection: removingDeadConnectionsRegex,
var logStateRegex = map[LogStateType]*StateRegexHandler{
SessionChange: sessionChangeRegex,
LeaderboardUpdate: leaderboardUpdateRegex,
UDPCount: udpCountRegex,
ClientsOnline: clientsOnlineRegex,
RemovingDeadConnection: removingDeadConnectionsRegex,
}
func (instance *AccServerInstance) HandleLogLine(line string) {
for logState, regexHandler := range logStateRegex {
if (regexHandler.Test(line)) {
switch logState {
case LeaderboardUpdate:
case UDPCount:
case ClientsOnline:
count := regexHandler.Count(line)
instance.UpdatePlayerCount(count)
case SessionChange:
_, new := regexHandler.Change(line)
instance.UpdateSessionChange(new)
case RemovingDeadConnection:
instance.UpdatePlayerCount(instance.State.PlayerCount - 1)
}
}
}
for logState, regexHandler := range logStateRegex {
if regexHandler.Test(line) {
switch logState {
case LeaderboardUpdate:
case UDPCount:
case ClientsOnline:
count := regexHandler.Count(line)
instance.UpdatePlayerCount(count)
case SessionChange:
_, new := regexHandler.Change(line)
trackSession := model.ToTrackSession(new)
instance.UpdateSessionChange(trackSession)
case RemovingDeadConnection:
instance.UpdatePlayerCount(instance.State.PlayerCount - 1)
}
}
}
}
func (instance *AccServerInstance) UpdateState(callback func(state *model.ServerState, changes *[]StateChange)) {
state := instance.State
changes := []StateChange{}
state.Lock()
defer state.Unlock()
callback(state, &changes)
if (len(changes) > 0) {
instance.OnStateChange(state, changes...)
}
state := instance.State
changes := []StateChange{}
state.Lock()
defer state.Unlock()
callback(state, &changes)
if len(changes) > 0 {
instance.OnStateChange(state, changes...)
}
}
func (instance *AccServerInstance) UpdatePlayerCount(count int) {
if (count < 0) {
return
}
instance.UpdateState(func (state *model.ServerState, changes *[]StateChange) {
if (count == state.PlayerCount) {
return
}
if (count > 0 && state.PlayerCount == 0) {
state.SessionStart = time.Now()
*changes = append(*changes, Session)
} else if (count == 0) {
state.SessionStart = time.Time{}
*changes = append(*changes, Session)
}
state.PlayerCount = count
*changes = append(*changes, PlayerCount)
})
if count < 0 {
return
}
instance.UpdateState(func(state *model.ServerState, changes *[]StateChange) {
if count == state.PlayerCount {
return
}
if count > 0 && state.PlayerCount == 0 {
state.SessionStart = time.Now()
*changes = append(*changes, Session)
} else if count == 0 {
state.SessionStart = time.Time{}
*changes = append(*changes, Session)
}
state.PlayerCount = count
*changes = append(*changes, PlayerCount)
})
}
func (instance *AccServerInstance) UpdateSessionChange(session string) {
instance.UpdateState(func (state *model.ServerState, changes *[]StateChange) {
if (session == state.Session) {
return
}
if (state.PlayerCount > 0) {
state.SessionStart = time.Now()
} else {
state.SessionStart = time.Time{}
}
state.Session = session
*changes = append(*changes, Session)
})
}
func (instance *AccServerInstance) UpdateSessionChange(session model.TrackSession) {
instance.UpdateState(func(state *model.ServerState, changes *[]StateChange) {
if session == state.Session {
return
}
if state.PlayerCount > 0 {
state.SessionStart = time.Now()
} else {
state.SessionStart = time.Time{}
}
state.Session = session
*changes = append(*changes, Session)
})
}

View File

@@ -0,0 +1,169 @@
package websocket
import (
"acc-server-manager/local/model"
"acc-server-manager/local/utl/logging"
"encoding/json"
"sync"
"time"
"github.com/gofiber/websocket/v2"
"github.com/google/uuid"
)
type WebSocketConnection struct {
conn *websocket.Conn
serverID *uuid.UUID
userID *uuid.UUID
}
type WebSocketService struct {
connections sync.Map
mu sync.RWMutex
}
func NewWebSocketService() *WebSocketService {
return &WebSocketService{}
}
func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, userID *uuid.UUID) {
wsConn := &WebSocketConnection{
conn: conn,
userID: userID,
}
ws.connections.Store(connID, wsConn)
logging.Info("WebSocket connection added: %s for user: %v", connID, userID)
}
func (ws *WebSocketService) RemoveConnection(connID string) {
if conn, exists := ws.connections.LoadAndDelete(connID); exists {
if wsConn, ok := conn.(*WebSocketConnection); ok {
wsConn.conn.Close()
}
}
logging.Info("WebSocket connection removed: %s", connID)
}
func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) {
if conn, exists := ws.connections.Load(connID); exists {
if wsConn, ok := conn.(*WebSocketConnection); ok {
wsConn.serverID = &serverID
}
}
}
func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerCreationStep, status model.StepStatus, message string, errorMsg string) {
stepMsg := model.StepMessage{
Step: step,
Status: status,
Message: message,
Error: errorMsg,
}
wsMsg := model.WebSocketMessage{
Type: model.MessageTypeStep,
ServerID: &serverID,
Timestamp: time.Now().Unix(),
Data: stepMsg,
}
ws.broadcastToServer(serverID, wsMsg)
}
func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output string, isError bool) {
steamMsg := model.SteamOutputMessage{
Output: output,
IsError: isError,
}
wsMsg := model.WebSocketMessage{
Type: model.MessageTypeSteamOutput,
ServerID: &serverID,
Timestamp: time.Now().Unix(),
Data: steamMsg,
}
ws.broadcastToServer(serverID, wsMsg)
}
func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, details string) {
errorMsg := model.ErrorMessage{
Error: error,
Details: details,
}
wsMsg := model.WebSocketMessage{
Type: model.MessageTypeError,
ServerID: &serverID,
Timestamp: time.Now().Unix(),
Data: errorMsg,
}
ws.broadcastToServer(serverID, wsMsg)
}
func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool, message string) {
completeMsg := model.CompleteMessage{
ServerID: serverID,
Success: success,
Message: message,
}
wsMsg := model.WebSocketMessage{
Type: model.MessageTypeComplete,
ServerID: &serverID,
Timestamp: time.Now().Unix(),
Data: completeMsg,
}
ws.broadcastToServer(serverID, wsMsg)
}
func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model.WebSocketMessage) {
data, err := json.Marshal(message)
if err != nil {
logging.Error("Failed to marshal WebSocket message: %v", err)
return
}
ws.connections.Range(func(key, value interface{}) bool {
if wsConn, ok := value.(*WebSocketConnection); ok {
if wsConn.serverID != nil && *wsConn.serverID == serverID {
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
ws.RemoveConnection(key.(string))
}
}
}
return true
})
}
func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebSocketMessage) {
data, err := json.Marshal(message)
if err != nil {
logging.Error("Failed to marshal WebSocket message: %v", err)
return
}
ws.connections.Range(func(key, value interface{}) bool {
if wsConn, ok := value.(*WebSocketConnection); ok {
if wsConn.userID != nil && *wsConn.userID == userID {
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
ws.RemoveConnection(key.(string))
}
}
}
return true
})
}
func (ws *WebSocketService) GetActiveConnections() int {
count := 0
ws.connections.Range(func(key, value interface{}) bool {
count++
return true
})
return count
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
basePath: /api/v1
basePath: /v1
definitions:
error_handler.ErrorResponse:
properties:
@@ -92,6 +92,35 @@ definitions:
- StatusRestarting
- StatusStarting
- StatusRunning
model.Steam2FARequest:
properties:
completedAt:
type: string
errorMsg:
type: string
id:
type: string
message:
type: string
requestTime:
type: string
serverId:
type: string
status:
$ref: '#/definitions/model.Steam2FAStatus'
type: object
model.Steam2FAStatus:
enum:
- idle
- pending
- complete
- error
type: string
x-enum-varnames:
- Steam2FAStatusIdle
- Steam2FAStatusPending
- Steam2FAStatusComplete
- Steam2FAStatusError
model.User:
properties:
id:
@@ -103,11 +132,20 @@ definitions:
username:
type: string
type: object
host: localhost:3000
service.UpdateUserRequest:
properties:
password:
type: string
roleId:
type: string
username:
type: string
type: object
host: acc-api.jurmanovic.com
info:
contact:
name: ACC Server Manager Support
url: https://github.com/yourusername/acc-server-manager
url: https://github.com/FJurmanovic/acc-server-manager
description: API for managing Assetto Corsa Competizione dedicated servers
license:
name: MIT
@@ -115,6 +153,62 @@ info:
title: ACC Server Manager API
version: "1.0"
paths:
/api/server:
get:
consumes:
- application/json
description: Get a list of all ACC servers with filtering options
parameters:
- in: query
name: name
type: string
- in: query
name: page
type: integer
- in: query
name: pageSize
type: integer
- in: query
name: serverID
type: string
- in: query
name: serviceName
type: string
- in: query
name: sortBy
type: string
- in: query
name: sortDesc
type: boolean
- in: query
name: status
type: string
produces:
- application/json
responses:
"200":
description: List of servers
schema:
items:
$ref: '#/definitions/model.ServerAPI'
type: array
"400":
description: Invalid filter parameters
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"401":
description: Unauthorized
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal server error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: List all servers (API format)
tags:
- Server
/auth/login:
post:
consumes:
@@ -157,6 +251,228 @@ paths:
summary: User login
tags:
- Authentication
/auth/me:
get:
consumes:
- application/json
description: Get details of the currently authenticated user
produces:
- application/json
responses:
"200":
description: Current user details
schema:
$ref: '#/definitions/model.User'
"401":
description: Unauthorized
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"404":
description: User not found
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: Get current user details
tags:
- Authentication
/auth/open-token:
post:
consumes:
- application/json
description: Generate an open token for a user
produces:
- application/json
responses:
"200":
description: JWT token
schema:
properties:
token:
type: string
type: object
"400":
description: Invalid request body
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"401":
description: Invalid credentials
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal server error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
summary: Generate an open token
tags:
- Authentication
/lookup/car-models:
get:
consumes:
- application/json
description: Get a list of all available ACC car models with their identifiers
produces:
- application/json
responses:
"200":
description: List of car models
schema:
items:
properties:
class:
type: string
id:
type: string
name:
type: string
type: object
type: array
"401":
description: Unauthorized
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal server error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: Get available car models
tags:
- Lookups
/lookup/cup-categories:
get:
consumes:
- application/json
description: Get a list of all available racing cup categories
produces:
- application/json
responses:
"200":
description: List of cup categories
schema:
items:
properties:
id:
type: number
name:
type: string
type: object
type: array
"401":
description: Unauthorized
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal server error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: Get cup categories
tags:
- Lookups
/lookup/driver-categories:
get:
consumes:
- application/json
description: Get a list of all driver categories (Bronze, Silver, Gold, Platinum)
produces:
- application/json
responses:
"200":
description: List of driver categories
schema:
items:
properties:
description:
type: string
id:
type: number
name:
type: string
type: object
type: array
"401":
description: Unauthorized
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal server error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: Get driver categories
tags:
- Lookups
/lookup/session-types:
get:
consumes:
- application/json
description: Get a list of all available session types (Practice, Qualifying,
Race)
produces:
- application/json
responses:
"200":
description: List of session types
schema:
items:
properties:
code:
type: string
id:
type: string
name:
type: string
type: object
type: array
"401":
description: Unauthorized
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal server error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: Get session types
tags:
- Lookups
/lookup/tracks:
get:
consumes:
- application/json
description: Get a list of all available ACC tracks with their identifiers
produces:
- application/json
responses:
"200":
description: List of tracks
schema:
items:
properties:
id:
type: string
name:
type: string
type: object
type: array
"401":
description: Unauthorized
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal server error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: Get available tracks
tags:
- Lookups
/membership:
get:
consumes:
@@ -238,117 +554,141 @@ paths:
summary: Create a new user
tags:
- User Management
/server:
get:
/membership/{id}:
delete:
consumes:
- application/json
description: Get a list of all ACC servers with filtering options
description: Delete a specific user by ID
parameters:
- in: query
name: name
type: string
- in: query
name: page
type: integer
- in: query
name: pageSize
type: integer
- in: query
name: serverID
type: string
- in: query
name: serviceName
type: string
- in: query
name: sortBy
type: string
- in: query
name: sortDesc
type: boolean
- in: query
name: status
- description: User ID (UUID format)
in: path
name: id
required: true
type: string
produces:
- application/json
responses:
"200":
description: List of servers
schema:
items:
$ref: '#/definitions/model.ServerAPI'
type: array
"204":
description: User successfully deleted
"400":
description: Invalid filter parameters
description: Invalid user ID format
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"401":
description: Unauthorized
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal server error
"403":
description: Insufficient permissions
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"404":
description: User not found
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: List all servers (API format)
summary: Delete user
tags:
- Server
/v1/lookup/car-models:
- User Management
get:
consumes:
- application/json
description: Get a list of all available ACC car models with their identifiers
description: Get detailed information about a specific user
parameters:
- description: User ID (UUID format)
in: path
name: id
required: true
type: string
produces:
- application/json
responses:
"200":
description: List of car models
description: User details
schema:
$ref: '#/definitions/model.User'
"400":
description: Invalid user ID format
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"401":
description: Unauthorized
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"404":
description: User not found
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: Get user by ID
tags:
- User Management
put:
consumes:
- application/json
description: Update user details by ID
parameters:
- description: User ID (UUID format)
in: path
name: id
required: true
type: string
- description: Updated user details
in: body
name: user
required: true
schema:
$ref: '#/definitions/service.UpdateUserRequest'
produces:
- application/json
responses:
"200":
description: Updated user details
schema:
$ref: '#/definitions/model.User'
"400":
description: Invalid request body or ID format
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"401":
description: Unauthorized
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"403":
description: Insufficient permissions
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"404":
description: User not found
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: Update user
tags:
- User Management
/membership/roles:
get:
consumes:
- application/json
description: Get a list of all available user roles
produces:
- application/json
responses:
"200":
description: List of roles
schema:
items:
properties:
class:
type: string
id:
type: string
name:
type: string
type: object
$ref: '#/definitions/model.Role'
type: array
"401":
description: Unauthorized
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal server error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: Get available car models
tags:
- Lookups
/v1/lookup/cup-categories:
get:
consumes:
- application/json
description: Get a list of all available racing cup categories
produces:
- application/json
responses:
"200":
description: List of cup categories
schema:
items:
properties:
id:
type: number
name:
type: string
type: object
type: array
"401":
description: Unauthorized
"403":
description: Insufficient permissions
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
@@ -357,111 +697,10 @@ paths:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: Get cup categories
summary: Get all roles
tags:
- Lookups
/v1/lookup/driver-categories:
get:
consumes:
- application/json
description: Get a list of all driver categories (Bronze, Silver, Gold, Platinum)
produces:
- application/json
responses:
"200":
description: List of driver categories
schema:
items:
properties:
description:
type: string
id:
type: number
name:
type: string
type: object
type: array
"401":
description: Unauthorized
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal server error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: Get driver categories
tags:
- Lookups
/v1/lookup/session-types:
get:
consumes:
- application/json
description: Get a list of all available session types (Practice, Qualifying,
Race)
produces:
- application/json
responses:
"200":
description: List of session types
schema:
items:
properties:
code:
type: string
id:
type: string
name:
type: string
type: object
type: array
"401":
description: Unauthorized
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal server error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: Get session types
tags:
- Lookups
/v1/lookup/tracks:
get:
consumes:
- application/json
description: Get a list of all available ACC tracks with their identifiers
produces:
- application/json
responses:
"200":
description: List of tracks
schema:
items:
properties:
id:
type: string
name:
type: string
type: object
type: array
"401":
description: Unauthorized
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal server error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: Get available tracks
tags:
- Lookups
/v1/server:
- User Management
/server:
get:
consumes:
- application/json
@@ -556,7 +795,49 @@ paths:
summary: Create a new ACC server
tags:
- Server
/v1/server/{id}:
/server/{id}:
delete:
consumes:
- application/json
description: Delete an existing ACC server
parameters:
- description: Server ID (UUID format)
in: path
name: id
required: true
type: string
produces:
- application/json
responses:
"200":
description: Deleted server details
schema:
type: object
"400":
description: Invalid server data or ID
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"401":
description: Unauthorized
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"403":
description: Insufficient permissions
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"404":
description: Server not found
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal server error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
security:
- BearerAuth: []
summary: Delete an ACC server
tags:
- Server
get:
consumes:
- application/json
@@ -643,7 +924,7 @@ paths:
summary: Update an ACC server
tags:
- Server
/v1/server/{id}/config:
/server/{id}/config:
get:
consumes:
- application/json
@@ -688,7 +969,7 @@ paths:
summary: List available configuration files
tags:
- Server Configuration
/v1/server/{id}/config/{file}:
/server/{id}/config/{file}:
get:
consumes:
- application/json
@@ -792,20 +1073,7 @@ paths:
summary: Update server configuration file
tags:
- Server Configuration
/v1/service-control:
get:
description: Return service control status
responses:
"200":
description: OK
schema:
items:
type: string
type: array
summary: Return service control status
tags:
- service-control
/v1/service-control/{service}:
/server/{id}/service/{service}:
get:
consumes:
- application/json
@@ -849,7 +1117,7 @@ paths:
summary: Get service status
tags:
- Service Control
/v1/service-control/restart:
/server/{id}/service/restart:
post:
consumes:
- application/json
@@ -899,7 +1167,7 @@ paths:
summary: Restart a Windows service
tags:
- Service Control
/v1/service-control/start:
/server/{id}/service/start:
post:
consumes:
- application/json
@@ -953,7 +1221,7 @@ paths:
summary: Start a Windows service
tags:
- Service Control
/v1/service-control/stop:
/server/{id}/service/stop:
post:
consumes:
- application/json
@@ -1007,7 +1275,7 @@ paths:
summary: Stop a Windows service
tags:
- Service Control
/v1/state-history:
/state-history:
get:
description: Return StateHistorys
responses:
@@ -1020,7 +1288,7 @@ paths:
summary: Return StateHistorys
tags:
- StateHistory
/v1/state-history/statistics:
/state-history/statistics:
get:
description: Return StateHistorys
responses:
@@ -1033,8 +1301,136 @@ paths:
summary: Return StateHistorys
tags:
- StateHistory
/steam2fa/{id}:
get:
consumes:
- application/json
description: Get a specific Steam 2FA authentication request by ID
parameters:
- description: 2FA Request ID
in: path
name: id
required: true
type: string
produces:
- application/json
responses:
"200":
description: OK
schema:
$ref: '#/definitions/model.Steam2FARequest'
"404":
description: Not Found
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
summary: Get 2FA request
tags:
- Steam 2FA
/steam2fa/{id}/cancel:
post:
consumes:
- application/json
description: Cancel a Steam 2FA authentication request
parameters:
- description: 2FA Request ID
in: path
name: id
required: true
type: string
produces:
- application/json
responses:
"200":
description: OK
schema:
$ref: '#/definitions/model.Steam2FARequest'
"400":
description: Bad Request
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"404":
description: Not Found
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
summary: Cancel 2FA request
tags:
- Steam 2FA
/steam2fa/{id}/complete:
post:
consumes:
- application/json
description: Mark a Steam 2FA authentication request as completed
parameters:
- description: 2FA Request ID
in: path
name: id
required: true
type: string
produces:
- application/json
responses:
"200":
description: OK
schema:
$ref: '#/definitions/model.Steam2FARequest'
"400":
description: Bad Request
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"404":
description: Not Found
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
summary: Complete 2FA request
tags:
- Steam 2FA
/steam2fa/pending:
get:
consumes:
- application/json
description: Get all pending Steam 2FA authentication requests
produces:
- application/json
responses:
"200":
description: OK
schema:
items:
$ref: '#/definitions/model.Steam2FARequest'
type: array
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/error_handler.ErrorResponse'
summary: Get pending 2FA requests
tags:
- Steam 2FA
/system/health:
get:
description: Return service control status
responses:
"200":
description: OK
schema:
items:
type: string
type: array
summary: Return service control status
tags:
- system
schemes:
- http
- https
securityDefinitions:
BearerAuth:

View File

@@ -4,22 +4,26 @@ import (
"acc-server-manager/local/model"
"acc-server-manager/local/utl/jwt"
"fmt"
"os"
"time"
"github.com/google/uuid"
)
// GenerateTestToken creates a JWT token for testing purposes
func GenerateTestToken() (string, error) {
// Create test user
user := &model.User{
ID: uuid.New(),
Username: "test_user",
RoleID: uuid.New(),
}
// Generate JWT token
token, err := jwt.GenerateToken(user)
testSecret := os.Getenv("JWT_SECRET")
if testSecret == "" {
testSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
}
jwtHandler := jwt.NewJWTHandler(testSecret)
token, err := jwtHandler.GenerateToken(user.ID.String())
if err != nil {
return "", fmt.Errorf("failed to generate test token: %w", err)
}
@@ -27,8 +31,6 @@ func GenerateTestToken() (string, error) {
return token, nil
}
// MustGenerateTestToken generates a test token and panics if it fails
// This is useful for test setup where failing to generate a token is a fatal error
func MustGenerateTestToken() string {
token, err := GenerateTestToken()
if err != nil {
@@ -37,17 +39,20 @@ func MustGenerateTestToken() string {
return token
}
// GenerateTestTokenWithExpiry creates a JWT token with a specific expiry time
func GenerateTestTokenWithExpiry(expiryTime time.Time) (string, error) {
// Create test user
testSecret := os.Getenv("JWT_SECRET")
if testSecret == "" {
testSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
}
jwtHandler := jwt.NewJWTHandler(testSecret)
user := &model.User{
ID: uuid.New(),
Username: "test_user",
RoleID: uuid.New(),
}
// Generate JWT token with custom expiry
token, err := jwt.GenerateTokenWithExpiry(user, expiryTime)
token, err := jwtHandler.GenerateTokenWithExpiry(user, expiryTime)
if err != nil {
return "", fmt.Errorf("failed to generate test token with expiry: %w", err)
}
@@ -55,8 +60,6 @@ func GenerateTestTokenWithExpiry(expiryTime time.Time) (string, error) {
return token, nil
}
// AddAuthHeader adds a test auth token to the request headers
// This is a convenience method for tests that need to authenticate requests
func AddAuthHeader(headers map[string]string) (map[string]string, error) {
token, err := GenerateTestToken()
if err != nil {
@@ -71,7 +74,6 @@ func AddAuthHeader(headers map[string]string) (map[string]string, error) {
return headers, nil
}
// MustAddAuthHeader adds a test auth token to the request headers and panics if it fails
func MustAddAuthHeader(headers map[string]string) map[string]string {
result, err := AddAuthHeader(headers)
if err != nil {

View File

@@ -8,26 +8,20 @@ import (
"github.com/google/uuid"
)
// MockAuthMiddleware provides a test implementation of AuthMiddleware
// that can be used as a drop-in replacement for the real AuthMiddleware
type MockAuthMiddleware struct{}
// NewMockAuthMiddleware creates a new MockAuthMiddleware
func NewMockAuthMiddleware() *MockAuthMiddleware {
return &MockAuthMiddleware{}
}
// Authenticate is a middleware that allows all requests without authentication for testing
func (m *MockAuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
// Set a mock user ID in context
mockUserID := uuid.New().String()
ctx.Locals("userID", mockUserID)
// Set mock user info
mockUserInfo := &middleware.CachedUserInfo{
UserID: mockUserID,
Username: "test_user",
RoleName: "Admin", // Admin role to bypass permission checks
RoleName: "Admin",
Permissions: map[string]bool{"*": true},
CachedAt: time.Now(),
}
@@ -38,21 +32,18 @@ func (m *MockAuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
return ctx.Next()
}
// HasPermission is a middleware that allows all permission checks to pass for testing
func (m *MockAuthMiddleware) HasPermission(requiredPermission string) fiber.Handler {
return func(ctx *fiber.Ctx) error {
return ctx.Next()
}
}
// AuthRateLimit is a test implementation that allows all requests
func (m *MockAuthMiddleware) AuthRateLimit() fiber.Handler {
return func(ctx *fiber.Ctx) error {
return ctx.Next()
}
}
// RequireHTTPS is a test implementation that allows all HTTP requests
func (m *MockAuthMiddleware) RequireHTTPS() fiber.Handler {
return func(ctx *fiber.Ctx) error {
return ctx.Next()

View File

@@ -8,7 +8,6 @@ import (
"github.com/google/uuid"
)
// MockConfigRepository provides a mock implementation of ConfigRepository
type MockConfigRepository struct {
configs map[string]*model.Config
shouldFailGet bool
@@ -21,7 +20,6 @@ func NewMockConfigRepository() *MockConfigRepository {
}
}
// UpdateConfig mocks the UpdateConfig method
func (m *MockConfigRepository) UpdateConfig(ctx context.Context, config *model.Config) *model.Config {
if m.shouldFailUpdate {
return nil
@@ -36,18 +34,15 @@ func (m *MockConfigRepository) UpdateConfig(ctx context.Context, config *model.C
return config
}
// SetShouldFailUpdate configures the mock to fail on UpdateConfig calls
func (m *MockConfigRepository) SetShouldFailUpdate(shouldFail bool) {
m.shouldFailUpdate = shouldFail
}
// GetConfig retrieves a config by server ID and config file
func (m *MockConfigRepository) GetConfig(serverID uuid.UUID, configFile string) *model.Config {
key := serverID.String() + "_" + configFile
return m.configs[key]
}
// MockServerRepository provides a mock implementation of ServerRepository
type MockServerRepository struct {
servers map[uuid.UUID]*model.Server
shouldFailGet bool
@@ -59,7 +54,6 @@ func NewMockServerRepository() *MockServerRepository {
}
}
// GetByID mocks the GetByID method
func (m *MockServerRepository) GetByID(ctx context.Context, id interface{}) (*model.Server, error) {
if m.shouldFailGet {
return nil, errors.New("server not found")
@@ -88,17 +82,14 @@ func (m *MockServerRepository) GetByID(ctx context.Context, id interface{}) (*mo
return server, nil
}
// AddServer adds a server to the mock repository
func (m *MockServerRepository) AddServer(server *model.Server) {
m.servers[server.ID] = server
}
// SetShouldFailGet configures the mock to fail on GetByID calls
func (m *MockServerRepository) SetShouldFailGet(shouldFail bool) {
m.shouldFailGet = shouldFail
}
// MockServerService provides a mock implementation of ServerService
type MockServerService struct {
startRuntimeCalled bool
startRuntimeServer *model.Server
@@ -108,23 +99,19 @@ func NewMockServerService() *MockServerService {
return &MockServerService{}
}
// StartAccServerRuntime mocks the StartAccServerRuntime method
func (m *MockServerService) StartAccServerRuntime(server *model.Server) {
m.startRuntimeCalled = true
m.startRuntimeServer = server
}
// WasStartRuntimeCalled returns whether StartAccServerRuntime was called
func (m *MockServerService) WasStartRuntimeCalled() bool {
return m.startRuntimeCalled
}
// GetStartRuntimeServer returns the server passed to StartAccServerRuntime
func (m *MockServerService) GetStartRuntimeServer() *model.Server {
return m.startRuntimeServer
}
// Reset resets the mock state
func (m *MockServerService) Reset() {
m.startRuntimeCalled = false
m.startRuntimeServer = nil

View File

@@ -8,7 +8,6 @@ import (
"github.com/google/uuid"
)
// MockStateHistoryRepository provides a mock implementation of StateHistoryRepository
type MockStateHistoryRepository struct {
stateHistories []model.StateHistory
shouldFailGet bool
@@ -21,7 +20,6 @@ func NewMockStateHistoryRepository() *MockStateHistoryRepository {
}
}
// GetAll mocks the GetAll method
func (m *MockStateHistoryRepository) GetAll(ctx context.Context, filter *model.StateHistoryFilter) (*[]model.StateHistory, error) {
if m.shouldFailGet {
return nil, errors.New("failed to get state history")
@@ -37,13 +35,11 @@ func (m *MockStateHistoryRepository) GetAll(ctx context.Context, filter *model.S
return &filtered, nil
}
// Insert mocks the Insert method
func (m *MockStateHistoryRepository) Insert(ctx context.Context, stateHistory *model.StateHistory) error {
if m.shouldFailInsert {
return errors.New("failed to insert state history")
}
// Simulate BeforeCreate hook
if stateHistory.ID == uuid.Nil {
stateHistory.ID = uuid.New()
}
@@ -55,7 +51,6 @@ func (m *MockStateHistoryRepository) Insert(ctx context.Context, stateHistory *m
return nil
}
// GetLastSessionID mocks the GetLastSessionID method
func (m *MockStateHistoryRepository) GetLastSessionID(ctx context.Context, serverID uuid.UUID) (uuid.UUID, error) {
for i := len(m.stateHistories) - 1; i >= 0; i-- {
if m.stateHistories[i].ServerID == serverID {
@@ -65,7 +60,6 @@ func (m *MockStateHistoryRepository) GetLastSessionID(ctx context.Context, serve
return uuid.Nil, nil
}
// Helper methods for filtering
func (m *MockStateHistoryRepository) matchesFilter(sh model.StateHistory, filter *model.StateHistoryFilter) bool {
if filter == nil {
return true
@@ -93,7 +87,6 @@ func (m *MockStateHistoryRepository) matchesFilter(sh model.StateHistory, filter
return true
}
// Helper methods for testing configuration
func (m *MockStateHistoryRepository) SetShouldFailGet(shouldFail bool) {
m.shouldFailGet = shouldFail
}
@@ -102,7 +95,6 @@ func (m *MockStateHistoryRepository) SetShouldFailInsert(shouldFail bool) {
m.shouldFailInsert = shouldFail
}
// AddStateHistory adds a state history entry to the mock repository
func (m *MockStateHistoryRepository) AddStateHistory(stateHistory model.StateHistory) {
if stateHistory.ID == uuid.Nil {
stateHistory.ID = uuid.New()
@@ -113,22 +105,18 @@ func (m *MockStateHistoryRepository) AddStateHistory(stateHistory model.StateHis
m.stateHistories = append(m.stateHistories, stateHistory)
}
// GetCount returns the number of state history entries
func (m *MockStateHistoryRepository) GetCount() int {
return len(m.stateHistories)
}
// Clear removes all state history entries
func (m *MockStateHistoryRepository) Clear() {
m.stateHistories = make([]model.StateHistory, 0)
}
// GetSummaryStats calculates peak players, total sessions, and average players for mock data
func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter *model.StateHistoryFilter) (model.StateHistoryStats, error) {
var stats model.StateHistoryStats
var filteredEntries []model.StateHistory
// Filter entries
for _, entry := range m.stateHistories {
if m.matchesFilter(entry, filter) {
filteredEntries = append(filteredEntries, entry)
@@ -139,7 +127,6 @@ func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter
return stats, nil
}
// Calculate statistics
sessionMap := make(map[string]bool)
totalPlayers := 0
@@ -159,11 +146,9 @@ func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter
return stats, nil
}
// GetTotalPlaytime calculates total playtime in minutes for mock data
func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *model.StateHistoryFilter) (int, error) {
var filteredEntries []model.StateHistory
// Filter entries
for _, entry := range m.stateHistories {
if m.matchesFilter(entry, filter) {
filteredEntries = append(filteredEntries, entry)
@@ -174,7 +159,6 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte
return 0, nil
}
// Group by session and calculate durations
sessionMap := make(map[string][]model.StateHistory)
for _, entry := range filteredEntries {
sessionID := entry.SessionID.String()
@@ -184,7 +168,6 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte
totalMinutes := 0
for _, sessionEntries := range sessionMap {
if len(sessionEntries) > 1 {
// Sort by date (simple approach for mock)
minTime := sessionEntries[0].DateCreated
maxTime := sessionEntries[0].DateCreated
hasPlayers := false
@@ -211,26 +194,22 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte
return totalMinutes, nil
}
// GetPlayerCountOverTime returns downsampled player count data for mock
func (m *MockStateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, filter *model.StateHistoryFilter) ([]model.PlayerCountPoint, error) {
var points []model.PlayerCountPoint
var filteredEntries []model.StateHistory
// Filter entries
for _, entry := range m.stateHistories {
if m.matchesFilter(entry, filter) {
filteredEntries = append(filteredEntries, entry)
}
}
// Group by hour (simple mock implementation)
hourMap := make(map[string][]int)
for _, entry := range filteredEntries {
hourKey := entry.DateCreated.Format("2006-01-02 15")
hourMap[hourKey] = append(hourMap[hourKey], entry.PlayerCount)
}
// Calculate averages per hour
for hourKey, counts := range hourMap {
total := 0
for _, count := range counts {
@@ -247,20 +226,17 @@ func (m *MockStateHistoryRepository) GetPlayerCountOverTime(ctx context.Context,
return points, nil
}
// GetSessionTypes counts sessions by type for mock
func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter *model.StateHistoryFilter) ([]model.SessionCount, error) {
var sessionTypes []model.SessionCount
var filteredEntries []model.StateHistory
// Filter entries
for _, entry := range m.stateHistories {
if m.matchesFilter(entry, filter) {
filteredEntries = append(filteredEntries, entry)
}
}
// Group by session type
sessionMap := make(map[string]map[string]bool) // session -> sessionID -> bool
sessionMap := make(map[model.TrackSession]map[string]bool)
for _, entry := range filteredEntries {
if sessionMap[entry.Session] == nil {
sessionMap[entry.Session] = make(map[string]bool)
@@ -268,7 +244,6 @@ func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter
sessionMap[entry.Session][entry.SessionID.String()] = true
}
// Count unique sessions per type
for sessionType, sessions := range sessionMap {
sessionTypes = append(sessionTypes, model.SessionCount{
Name: sessionType,
@@ -279,20 +254,17 @@ func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter
return sessionTypes, nil
}
// GetDailyActivity counts sessions per day for mock
func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filter *model.StateHistoryFilter) ([]model.DailyActivity, error) {
var dailyActivity []model.DailyActivity
var filteredEntries []model.StateHistory
// Filter entries
for _, entry := range m.stateHistories {
if m.matchesFilter(entry, filter) {
filteredEntries = append(filteredEntries, entry)
}
}
// Group by day
dayMap := make(map[string]map[string]bool) // date -> sessionID -> bool
dayMap := make(map[string]map[string]bool)
for _, entry := range filteredEntries {
dateKey := entry.DateCreated.Format("2006-01-02")
if dayMap[dateKey] == nil {
@@ -301,7 +273,6 @@ func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filte
dayMap[dateKey][entry.SessionID.String()] = true
}
// Count unique sessions per day
for date, sessions := range dayMap {
dailyActivity = append(dailyActivity, model.DailyActivity{
Date: date,
@@ -312,26 +283,22 @@ func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filte
return dailyActivity, nil
}
// GetRecentSessions retrieves recent sessions for mock
func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filter *model.StateHistoryFilter) ([]model.RecentSession, error) {
var recentSessions []model.RecentSession
var filteredEntries []model.StateHistory
// Filter entries
for _, entry := range m.stateHistories {
if m.matchesFilter(entry, filter) {
filteredEntries = append(filteredEntries, entry)
}
}
// Group by session
sessionMap := make(map[string][]model.StateHistory)
for _, entry := range filteredEntries {
sessionID := entry.SessionID.String()
sessionMap[sessionID] = append(sessionMap[sessionID], entry)
}
// Create recent sessions (limit to 10)
count := 0
for _, entries := range sessionMap {
if count >= 10 {
@@ -339,7 +306,6 @@ func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filt
}
if len(entries) > 0 {
// Find min/max dates and max players
minDate := entries[0].DateCreated
maxDate := entries[0].DateCreated
maxPlayers := 0
@@ -356,11 +322,10 @@ func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filt
}
}
// Only include sessions with players
if maxPlayers > 0 {
duration := int(maxDate.Sub(minDate).Minutes())
recentSessions = append(recentSessions, model.RecentSession{
ID: uint(count + 1),
ID: entries[0].SessionID,
Date: minDate.Format("2006-01-02 15:04:05"),
Type: entries[0].Session,
Track: entries[0].Track,

View File

@@ -3,7 +3,6 @@ package tests
import (
"acc-server-manager/local/model"
"acc-server-manager/local/utl/configs"
"acc-server-manager/local/utl/jwt"
"bytes"
"context"
"errors"
@@ -25,14 +24,12 @@ import (
"gorm.io/gorm/logger"
)
// TestHelper provides utilities for testing
type TestHelper struct {
DB *gorm.DB
TempDir string
TestData *TestData
}
// TestData contains common test data structures
type TestData struct {
ServerID uuid.UUID
Server *model.Server
@@ -40,38 +37,29 @@ type TestData struct {
SampleConfig *model.Configuration
}
// SetTestEnv sets the required environment variables for tests
func SetTestEnv() {
// Set required environment variables for testing
os.Setenv("APP_SECRET", "test-secret-key-for-testing-123456")
os.Setenv("APP_SECRET_CODE", "test-code-for-testing-123456789012")
os.Setenv("ENCRYPTION_KEY", "12345678901234567890123456789012")
os.Setenv("JWT_SECRET", "test-jwt-secret-key-for-testing-123456789012345678901234567890")
os.Setenv("ACCESS_KEY", "test-access-key-for-testing")
// Set test-specific environment variables
os.Setenv("TESTING_ENV", "true") // Used to bypass
os.Setenv("TESTING_ENV", "true")
configs.Init()
jwt.Init()
}
// NewTestHelper creates a new test helper with in-memory database
func NewTestHelper(t *testing.T) *TestHelper {
// Set required environment variables
SetTestEnv()
// Create temporary directory for test files
tempDir := t.TempDir()
// Create in-memory SQLite database for testing
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent), // Suppress SQL logs in tests
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("Failed to connect to test database: %v", err)
}
// Auto-migrate the schema
err = db.AutoMigrate(
&model.Server{},
&model.Config{},
@@ -81,7 +69,6 @@ func NewTestHelper(t *testing.T) *TestHelper {
&model.StateHistory{},
)
// Explicitly ensure tables exist with correct structure
if !db.Migrator().HasTable(&model.StateHistory{}) {
err = db.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -92,7 +79,6 @@ func NewTestHelper(t *testing.T) *TestHelper {
t.Fatalf("Failed to migrate test database: %v", err)
}
// Create test data
testData := createTestData(t, tempDir)
return &TestHelper{
@@ -102,11 +88,9 @@ func NewTestHelper(t *testing.T) *TestHelper {
}
}
// createTestData creates common test data structures
func createTestData(t *testing.T, tempDir string) *TestData {
serverID := uuid.New()
// Create sample server
server := &model.Server{
ID: serverID,
Name: "Test Server",
@@ -117,13 +101,11 @@ func createTestData(t *testing.T, tempDir string) *TestData {
FromSteamCMD: false,
}
// Create server directory
serverConfigDir := filepath.Join(tempDir, "server", "cfg")
if err := os.MkdirAll(serverConfigDir, 0755); err != nil {
t.Fatalf("Failed to create server config directory: %v", err)
}
// Sample configuration files content
configFiles := map[string]string{
"configuration.json": `{
"udpPort": "9231",
@@ -214,7 +196,6 @@ func createTestData(t *testing.T, tempDir string) *TestData {
}`,
}
// Sample configuration struct
sampleConfig := &model.Configuration{
UdpPort: model.IntString(9231),
TcpPort: model.IntString(9232),
@@ -232,14 +213,12 @@ func createTestData(t *testing.T, tempDir string) *TestData {
}
}
// CreateTestConfigFiles creates actual config files in the test directory
func (th *TestHelper) CreateTestConfigFiles() error {
serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg")
for filename, content := range th.TestData.ConfigFiles {
filePath := filepath.Join(serverConfigDir, filename)
// Encode content to UTF-16 LE BOM format as expected by the application
utf16Content, err := EncodeUTF16LEBOM([]byte(content))
if err != nil {
return err
@@ -253,7 +232,6 @@ func (th *TestHelper) CreateTestConfigFiles() error {
return nil
}
// CreateMalformedConfigFile creates a config file with invalid JSON
func (th *TestHelper) CreateMalformedConfigFile(filename string) error {
serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg")
filePath := filepath.Join(serverConfigDir, filename)
@@ -261,67 +239,51 @@ func (th *TestHelper) CreateMalformedConfigFile(filename string) error {
malformedJSON := `{
"udpPort": "9231",
"tcpPort": "9232"
"maxConnections": "30" // Missing comma - invalid JSON
"maxConnections": "30"
}`
return os.WriteFile(filePath, []byte(malformedJSON), 0644)
}
// RemoveConfigFile removes a config file to simulate missing file scenarios
func (th *TestHelper) RemoveConfigFile(filename string) error {
serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg")
filePath := filepath.Join(serverConfigDir, filename)
return os.Remove(filePath)
}
// InsertTestServer inserts the test server into the database
func (th *TestHelper) InsertTestServer() error {
return th.DB.Create(th.TestData.Server).Error
}
// CreateContext creates a test context
func (th *TestHelper) CreateContext() context.Context {
return context.Background()
}
// CreateFiberCtx creates a fiber.Ctx for testing
func (th *TestHelper) CreateFiberCtx() *fiber.Ctx {
// Create app and request for fiber context
app := fiber.New()
// Create a dummy request that doesn't depend on external http objects
req := httptest.NewRequest("GET", "/", nil)
// Create the fiber context from real request/response
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
// Store the original request for release later
ctx.Locals("original-request", req)
// Return the context which can be safely used in tests
return ctx
}
// ReleaseFiberCtx properly releases a fiber context created with CreateFiberCtx
func (th *TestHelper) ReleaseFiberCtx(app *fiber.App, ctx *fiber.Ctx) {
if app != nil && ctx != nil {
app.ReleaseCtx(ctx)
}
}
// Cleanup performs cleanup operations after tests
func (th *TestHelper) Cleanup() {
// Close database connection
if sqlDB, err := th.DB.DB(); err == nil {
sqlDB.Close()
}
// Temporary directory is automatically cleaned up by t.TempDir()
}
// LoadTestEnvFile loads environment variables from a .env file for testing
func LoadTestEnvFile() error {
// Try to load from .env file
return godotenv.Load()
}
// AssertNoError is a helper function to check for errors in tests
func AssertNoError(t *testing.T, err error) {
t.Helper()
if err != nil {
@@ -329,7 +291,6 @@ func AssertNoError(t *testing.T, err error) {
}
}
// AssertError is a helper function to check for expected errors
func AssertError(t *testing.T, err error, expectedMsg string) {
t.Helper()
if err == nil {
@@ -340,7 +301,6 @@ func AssertError(t *testing.T, err error, expectedMsg string) {
}
}
// AssertEqual checks if two values are equal
func AssertEqual(t *testing.T, expected, actual interface{}) {
t.Helper()
if expected != actual {
@@ -348,7 +308,6 @@ func AssertEqual(t *testing.T, expected, actual interface{}) {
}
}
// AssertNotNil checks if a value is not nil
func AssertNotNil(t *testing.T, value interface{}) {
t.Helper()
if value == nil {
@@ -356,12 +315,9 @@ func AssertNotNil(t *testing.T, value interface{}) {
}
}
// AssertNil checks if a value is nil
func AssertNil(t *testing.T, value interface{}) {
t.Helper()
if value != nil {
// Special handling for interface values that contain nil but aren't nil themselves
// For example, (*jwt.Claims)(nil) is not equal to nil, but it contains nil
switch v := value.(type) {
case *interface{}:
if v == nil || *v == nil {
@@ -376,13 +332,11 @@ func AssertNil(t *testing.T, value interface{}) {
}
}
// EncodeUTF16LEBOM encodes UTF-8 bytes to UTF-16 LE BOM format
func EncodeUTF16LEBOM(input []byte) ([]byte, error) {
encoder := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM)
return transformBytes(encoder.NewEncoder(), input)
}
// transformBytes applies a transform to input bytes
func transformBytes(t transform.Transformer, input []byte) ([]byte, error) {
var buf bytes.Buffer
w := transform.NewWriter(&buf, t)
@@ -398,7 +352,6 @@ func transformBytes(t transform.Transformer, input []byte) ([]byte, error) {
return buf.Bytes(), nil
}
// ErrorForTesting creates an error for testing purposes
func ErrorForTesting(message string) error {
return errors.New(message)
}

View File

@@ -7,13 +7,11 @@ import (
"github.com/google/uuid"
)
// StateHistoryTestData provides simple test data generators
type StateHistoryTestData struct {
ServerID uuid.UUID
BaseTime time.Time
}
// NewStateHistoryTestData creates a new test data generator
func NewStateHistoryTestData(serverID uuid.UUID) *StateHistoryTestData {
return &StateHistoryTestData{
ServerID: serverID,
@@ -21,8 +19,7 @@ func NewStateHistoryTestData(serverID uuid.UUID) *StateHistoryTestData {
}
}
// CreateStateHistory creates a basic state history entry
func (td *StateHistoryTestData) CreateStateHistory(session string, track string, playerCount int, sessionID uuid.UUID) model.StateHistory {
func (td *StateHistoryTestData) CreateStateHistory(session model.TrackSession, track string, playerCount int, sessionID uuid.UUID) model.StateHistory {
return model.StateHistory{
ID: uuid.New(),
ServerID: td.ServerID,
@@ -36,8 +33,7 @@ func (td *StateHistoryTestData) CreateStateHistory(session string, track string,
}
}
// CreateMultipleEntries creates multiple state history entries for the same session
func (td *StateHistoryTestData) CreateMultipleEntries(session string, track string, playerCounts []int) []model.StateHistory {
func (td *StateHistoryTestData) CreateMultipleEntries(session model.TrackSession, track string, playerCounts []int) []model.StateHistory {
sessionID := uuid.New()
var entries []model.StateHistory
@@ -59,7 +55,6 @@ func (td *StateHistoryTestData) CreateMultipleEntries(session string, track stri
return entries
}
// CreateBasicFilter creates a basic filter for testing
func CreateBasicFilter(serverID string) *model.StateHistoryFilter {
return &model.StateHistoryFilter{
ServerBasedFilter: model.ServerBasedFilter{
@@ -68,8 +63,7 @@ func CreateBasicFilter(serverID string) *model.StateHistoryFilter {
}
}
// CreateFilterWithSession creates a filter with session type
func CreateFilterWithSession(serverID string, session string) *model.StateHistoryFilter {
func CreateFilterWithSession(serverID string, session model.TrackSession) *model.StateHistoryFilter {
return &model.StateHistoryFilter{
ServerBasedFilter: model.ServerBasedFilter{
ServerID: serverID,
@@ -78,7 +72,6 @@ func CreateFilterWithSession(serverID string, session string) *model.StateHistor
}
}
// LogLines contains sample ACC server log lines for testing
var SampleLogLines = []string{
"[2024-01-15 14:30:25.123] Session changed: NONE -> PRACTICE",
"[2024-01-15 14:30:30.456] 1 client(s) online",
@@ -95,16 +88,14 @@ var SampleLogLines = []string{
"[2024-01-15 15:00:05.123] Session changed: RACE -> NONE",
}
// ExpectedSessionChanges represents the expected session changes from parsing the sample log lines
var ExpectedSessionChanges = []struct {
From string
To string
From model.TrackSession
To model.TrackSession
}{
{"NONE", "PRACTICE"},
{"PRACTICE", "QUALIFY"},
{"QUALIFY", "RACE"},
{"RACE", "NONE"},
{model.SessionUnknown, model.SessionPractice},
{model.SessionPractice, model.SessionQualify},
{model.SessionQualify, model.SessionRace},
{model.SessionRace, model.SessionUnknown},
}
// ExpectedPlayerCounts represents the expected player counts from parsing the sample log lines
var ExpectedPlayerCounts = []int{1, 3, 5, 8, 12, 15, 14, 0}

View File

@@ -14,12 +14,10 @@ import (
)
func TestController_JSONParsing_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test basic JSON parsing functionality
app := fiber.New()
app.Post("/test", func(c *fiber.Ctx) error {
@@ -30,7 +28,6 @@ func TestController_JSONParsing_Success(t *testing.T) {
return c.JSON(data)
})
// Prepare test data
testData := map[string]interface{}{
"name": "test",
"value": 123,
@@ -38,34 +35,28 @@ func TestController_JSONParsing_Success(t *testing.T) {
bodyBytes, err := json.Marshal(testData)
tests.AssertNoError(t, err)
// Create request
req := httptest.NewRequest("POST", "/test", bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
// Parse response
var response map[string]interface{}
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
err = json.Unmarshal(body, &response)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, "test", response["name"])
tests.AssertEqual(t, float64(123), response["value"]) // JSON numbers are float64
tests.AssertEqual(t, float64(123), response["value"])
}
func TestController_JSONParsing_InvalidJSON(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test handling of invalid JSON
app := fiber.New()
app.Post("/test", func(c *fiber.Ctx) error {
@@ -76,39 +67,31 @@ func TestController_JSONParsing_InvalidJSON(t *testing.T) {
return c.JSON(data)
})
// Create request with invalid JSON
req := httptest.NewRequest("POST", "/test", bytes.NewReader([]byte("invalid json")))
req.Header.Set("Content-Type", "application/json")
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 400, resp.StatusCode)
// Parse error response
var response map[string]interface{}
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
err = json.Unmarshal(body, &response)
tests.AssertNoError(t, err)
// Verify error response
tests.AssertEqual(t, "Invalid JSON", response["error"])
}
func TestController_UUIDValidation_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test UUID parameter validation
app := fiber.New()
app.Get("/test/:id", func(c *fiber.Ctx) error {
id := c.Params("id")
// Validate UUID
if _, err := uuid.Parse(id); err != nil {
return c.Status(400).JSON(fiber.Map{"error": "Invalid UUID"})
}
@@ -116,40 +99,33 @@ func TestController_UUIDValidation_Success(t *testing.T) {
return c.JSON(fiber.Map{"id": id, "valid": true})
})
// Create request with valid UUID
validUUID := uuid.New().String()
req := httptest.NewRequest("GET", "/test/"+validUUID, nil)
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
// Parse response
var response map[string]interface{}
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
err = json.Unmarshal(body, &response)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, validUUID, response["id"])
tests.AssertEqual(t, true, response["valid"])
}
func TestController_UUIDValidation_InvalidUUID(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test handling of invalid UUID
app := fiber.New()
app.Get("/test/:id", func(c *fiber.Ctx) error {
id := c.Params("id")
// Validate UUID
if _, err := uuid.Parse(id); err != nil {
return c.Status(400).JSON(fiber.Map{"error": "Invalid UUID"})
}
@@ -157,32 +133,26 @@ func TestController_UUIDValidation_InvalidUUID(t *testing.T) {
return c.JSON(fiber.Map{"id": id, "valid": true})
})
// Create request with invalid UUID
req := httptest.NewRequest("GET", "/test/invalid-uuid", nil)
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 400, resp.StatusCode)
// Parse error response
var response map[string]interface{}
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
err = json.Unmarshal(body, &response)
tests.AssertNoError(t, err)
// Verify error response
tests.AssertEqual(t, "Invalid UUID", response["error"])
}
func TestController_QueryParameters_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test query parameter handling
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
@@ -197,34 +167,28 @@ func TestController_QueryParameters_Success(t *testing.T) {
})
})
// Create request with query parameters
req := httptest.NewRequest("GET", "/test?restart=true&override=false&format=xml", nil)
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
// Parse response
var response map[string]interface{}
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
err = json.Unmarshal(body, &response)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, true, response["restart"])
tests.AssertEqual(t, false, response["override"])
tests.AssertEqual(t, "xml", response["format"])
}
func TestController_HTTPMethods_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test different HTTP methods
app := fiber.New()
var getCalled, postCalled, putCalled, deleteCalled bool
@@ -249,28 +213,24 @@ func TestController_HTTPMethods_Success(t *testing.T) {
return c.JSON(fiber.Map{"method": "DELETE"})
})
// Test GET
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
tests.AssertEqual(t, true, getCalled)
// Test POST
req = httptest.NewRequest("POST", "/test", nil)
resp, err = app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
tests.AssertEqual(t, true, postCalled)
// Test PUT
req = httptest.NewRequest("PUT", "/test", nil)
resp, err = app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
tests.AssertEqual(t, true, putCalled)
// Test DELETE
req = httptest.NewRequest("DELETE", "/test", nil)
resp, err = app.Test(req)
tests.AssertNoError(t, err)
@@ -279,12 +239,10 @@ func TestController_HTTPMethods_Success(t *testing.T) {
}
func TestController_ErrorHandling_StatusCodes(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test different error status codes
app := fiber.New()
app.Get("/400", func(c *fiber.Ctx) error {
@@ -307,7 +265,6 @@ func TestController_ErrorHandling_StatusCodes(t *testing.T) {
return c.Status(500).JSON(fiber.Map{"error": "Internal Server Error"})
})
// Test different status codes
testCases := []struct {
path string
code int
@@ -328,12 +285,10 @@ func TestController_ErrorHandling_StatusCodes(t *testing.T) {
}
func TestController_ConfigurationModel_JSONSerialization(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test Configuration model JSON serialization
app := fiber.New()
app.Get("/config", func(c *fiber.Ctx) error {
@@ -348,22 +303,18 @@ func TestController_ConfigurationModel_JSONSerialization(t *testing.T) {
return c.JSON(config)
})
// Create request
req := httptest.NewRequest("GET", "/config", nil)
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
// Parse response
var response model.Configuration
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
err = json.Unmarshal(body, &response)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, model.IntString(9231), response.UdpPort)
tests.AssertEqual(t, model.IntString(9232), response.TcpPort)
tests.AssertEqual(t, model.IntString(30), response.MaxConnections)
@@ -373,73 +324,61 @@ func TestController_ConfigurationModel_JSONSerialization(t *testing.T) {
}
func TestController_UserModel_JSONSerialization(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test User model JSON serialization (password should be hidden)
app := fiber.New()
app.Get("/user", func(c *fiber.Ctx) error {
user := &model.User{
ID: uuid.New(),
Username: "testuser",
Password: "secret-password", // Should not appear in JSON
Password: "secret-password",
RoleID: uuid.New(),
}
return c.JSON(user)
})
// Create request
req := httptest.NewRequest("GET", "/user", nil)
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
// Parse response as raw JSON to check password is excluded
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
// Verify password field is not in JSON
if bytes.Contains(body, []byte("password")) || bytes.Contains(body, []byte("secret-password")) {
t.Fatal("Password should not be included in JSON response")
}
// Verify other fields are present
if !bytes.Contains(body, []byte("username")) || !bytes.Contains(body, []byte("testuser")) {
t.Fatal("Username should be included in JSON response")
}
}
func TestController_MiddlewareChaining_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Test middleware chaining
app := fiber.New()
var middleware1Called, middleware2Called, handlerCalled bool
// Middleware 1
middleware1 := func(c *fiber.Ctx) error {
middleware1Called = true
c.Locals("middleware1", "executed")
return c.Next()
}
// Middleware 2
middleware2 := func(c *fiber.Ctx) error {
middleware2Called = true
c.Locals("middleware2", "executed")
return c.Next()
}
// Handler
handler := func(c *fiber.Ctx) error {
handlerCalled = true
return c.JSON(fiber.Map{
@@ -451,27 +390,22 @@ func TestController_MiddlewareChaining_Success(t *testing.T) {
app.Get("/test", middleware1, middleware2, handler)
// Create request
req := httptest.NewRequest("GET", "/test", nil)
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 200, resp.StatusCode)
// Verify all were called
tests.AssertEqual(t, true, middleware1Called)
tests.AssertEqual(t, true, middleware2Called)
tests.AssertEqual(t, true, handlerCalled)
// Parse response
var response map[string]interface{}
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
err = json.Unmarshal(body, &response)
tests.AssertNoError(t, err)
// Verify middleware values were passed
tests.AssertEqual(t, "executed", response["middleware1"])
tests.AssertEqual(t, "executed", response["middleware2"])
tests.AssertEqual(t, "executed", response["handler"])

View File

@@ -4,23 +4,27 @@ import (
"acc-server-manager/local/middleware"
"acc-server-manager/local/service"
"acc-server-manager/local/utl/cache"
"acc-server-manager/local/utl/jwt"
"acc-server-manager/tests"
"os"
"github.com/gofiber/fiber/v2"
)
// MockMiddleware simulates authentication for testing purposes
type MockMiddleware struct{}
// GetTestAuthMiddleware returns a mock auth middleware that can be used in place of the real one
// This works because we're adding real authentication tokens to requests
func GetTestAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache) *middleware.AuthMiddleware {
// Cast our mock to the real type for testing
// This is a type-unsafe cast but works for testing because we're using real JWT tokens
return middleware.NewAuthMiddleware(ms, cache)
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
}
jwtHandler := jwt.NewJWTHandler(jwtSecret)
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
return middleware.NewAuthMiddleware(ms, cache, jwtHandler, openJWTHandler)
}
// AddAuthToRequest adds a valid authentication token to a test request
func AddAuthToRequest(req *fiber.Ctx) {
token := tests.MustGenerateTestToken()
req.Request().Header.Set("Authorization", "Bearer "+token)

View File

@@ -7,6 +7,7 @@ import (
"acc-server-manager/local/service"
"acc-server-manager/local/utl/cache"
"acc-server-manager/local/utl/common"
"acc-server-manager/local/utl/jwt"
"acc-server-manager/tests"
"acc-server-manager/tests/testdata"
"encoding/json"
@@ -14,6 +15,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/gofiber/fiber/v2"
@@ -21,48 +23,45 @@ import (
)
func TestStateHistoryController_GetAll_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// No need for DisableAuthentication, we'll use real auth tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
membershipRepo := repository.NewMembershipRepository(helper.DB)
membershipService := service.NewMembershipService(membershipRepo)
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
}
jwtHandler := jwt.NewJWTHandler(jwtSecret)
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
inMemCache := cache.NewInMemoryCache()
// Insert test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
history := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
err := repo.Insert(helper.CreateContext(), &history)
tests.AssertNoError(t, err)
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Create request with authentication
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
// Parse response body
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
@@ -70,58 +69,55 @@ func TestStateHistoryController_GetAll_Success(t *testing.T) {
err = json.Unmarshal(body, &result)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 1, len(result))
tests.AssertEqual(t, "Practice", result[0].Session)
tests.AssertEqual(t, model.SessionPractice, result[0].Session)
tests.AssertEqual(t, 5, result[0].PlayerCount)
}
func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
membershipRepo := repository.NewMembershipRepository(helper.DB)
membershipService := service.NewMembershipService(membershipRepo)
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
}
jwtHandler := jwt.NewJWTHandler(jwtSecret)
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
inMemCache := cache.NewInMemoryCache()
// Insert test data with different sessions
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
practiceHistory := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
raceHistory := testData.CreateStateHistory("Race", "spa", 10, uuid.New())
practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
raceHistory := testData.CreateStateHistory(model.SessionRace, "spa", 10, uuid.New())
err := repo.Insert(helper.CreateContext(), &practiceHistory)
tests.AssertNoError(t, err)
err = repo.Insert(helper.CreateContext(), &raceHistory)
tests.AssertNoError(t, err)
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Create request with session filter and authentication
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s&session=Race", helper.TestData.ServerID.String()), nil)
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s&session=R", helper.TestData.ServerID.String()), nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
// Parse response body
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
@@ -129,105 +125,98 @@ func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) {
err = json.Unmarshal(body, &result)
tests.AssertNoError(t, err)
tests.AssertEqual(t, 1, len(result))
tests.AssertEqual(t, "Race", result[0].Session)
tests.AssertEqual(t, model.SessionRace, result[0].Session)
tests.AssertEqual(t, 10, result[0].PlayerCount)
}
func TestStateHistoryController_GetAll_EmptyResult(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
membershipRepo := repository.NewMembershipRepository(helper.DB)
membershipService := service.NewMembershipService(membershipRepo)
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
}
jwtHandler := jwt.NewJWTHandler(jwtSecret)
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
inMemCache := cache.NewInMemoryCache()
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Create request with no data and authentication
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
// Verify empty response
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
}
func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
// Skip this test as it requires more complex setup
t.Skip("Skipping test due to UUID validation issues")
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
membershipRepo := repository.NewMembershipRepository(helper.DB)
membershipService := service.NewMembershipService(membershipRepo)
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
}
jwtHandler := jwt.NewJWTHandler(jwtSecret)
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
inMemCache := cache.NewInMemoryCache()
// Insert test data with multiple entries for statistics
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
// Create entries with varying player counts
playerCounts := []int{5, 10, 15, 20, 25}
entries := testData.CreateMultipleEntries("Race", "spa", playerCounts)
entries := testData.CreateMultipleEntries(model.SessionRace, "spa", playerCounts)
for _, entry := range entries {
err := repo.Insert(helper.CreateContext(), &entry)
tests.AssertNoError(t, err)
}
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Create request with valid serverID UUID
validServerID := helper.TestData.ServerID.String()
if validServerID == "" {
validServerID = uuid.New().String() // Generate a new valid UUID if needed
validServerID = uuid.New().String()
}
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil)
req.Header.Set("Content-Type", "application/json")
// Add Authorization header for testing
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
// Parse response body
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
@@ -235,7 +224,6 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
err = json.Unmarshal(body, &stats)
tests.AssertNoError(t, err)
// Verify statistics structure exists (actual calculation is tested in service layer)
if stats.PeakPlayers < 0 {
t.Error("Expected non-negative peak players")
}
@@ -248,51 +236,47 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
}
func TestStateHistoryController_GetStatistics_NoData(t *testing.T) {
// Skip this test as it requires more complex setup
t.Skip("Skipping test due to UUID validation issues")
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
membershipRepo := repository.NewMembershipRepository(helper.DB)
membershipService := service.NewMembershipService(membershipRepo)
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
}
jwtHandler := jwt.NewJWTHandler(jwtSecret)
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
inMemCache := cache.NewInMemoryCache()
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Create request with valid serverID UUID
validServerID := helper.TestData.ServerID.String()
if validServerID == "" {
validServerID = uuid.New().String() // Generate a new valid UUID if needed
validServerID = uuid.New().String()
}
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil)
req.Header.Set("Content-Type", "application/json")
// Add Authorization header for testing
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
// Verify response
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
// Parse response body
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
@@ -300,105 +284,99 @@ func TestStateHistoryController_GetStatistics_NoData(t *testing.T) {
err = json.Unmarshal(body, &stats)
tests.AssertNoError(t, err)
// Verify empty statistics
tests.AssertEqual(t, 0, stats.PeakPlayers)
tests.AssertEqual(t, 0.0, stats.AveragePlayers)
tests.AssertEqual(t, 0, stats.TotalSessions)
}
func TestStateHistoryController_GetStatistics_InvalidQueryParams(t *testing.T) {
// Skip this test as it requires more complex setup
t.Skip("Skipping test due to UUID validation issues")
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
membershipRepo := repository.NewMembershipRepository(helper.DB)
membershipService := service.NewMembershipService(membershipRepo)
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
}
jwtHandler := jwt.NewJWTHandler(jwtSecret)
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
inMemCache := cache.NewInMemoryCache()
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Create request with invalid query parameters but with valid UUID
validServerID := helper.TestData.ServerID.String()
if validServerID == "" {
validServerID = uuid.New().String() // Generate a new valid UUID if needed
validServerID = uuid.New().String()
}
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s&min_players=invalid", validServerID), nil)
req.Header.Set("Content-Type", "application/json")
// Add Authorization header for testing
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
// Execute request
resp, err := app.Test(req)
tests.AssertNoError(t, err)
// Verify error response
tests.AssertEqual(t, http.StatusBadRequest, resp.StatusCode)
}
func TestStateHistoryController_HTTPMethods(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
membershipRepo := repository.NewMembershipRepository(helper.DB)
membershipService := service.NewMembershipService(membershipRepo)
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
}
jwtHandler := jwt.NewJWTHandler(jwtSecret)
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
inMemCache := cache.NewInMemoryCache()
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Test that only GET method is allowed for GetAll
req := httptest.NewRequest("POST", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
resp, err := app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode)
// Test that only GET method is allowed for GetStatistics
req = httptest.NewRequest("POST", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
resp, err = app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode)
// Test that PUT method is not allowed
req = httptest.NewRequest("PUT", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
resp, err = app.Test(req)
tests.AssertNoError(t, err)
tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode)
// Test that DELETE method is not allowed
req = httptest.NewRequest("DELETE", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
resp, err = app.Test(req)
@@ -408,58 +386,55 @@ func TestStateHistoryController_HTTPMethods(t *testing.T) {
func TestStateHistoryController_ContentType(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
membershipRepo := repository.NewMembershipRepository(helper.DB)
membershipService := service.NewMembershipService(membershipRepo)
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
}
jwtHandler := jwt.NewJWTHandler(jwtSecret)
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
inMemCache := cache.NewInMemoryCache()
// Insert test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
history := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
err := repo.Insert(helper.CreateContext(), &history)
tests.AssertNoError(t, err)
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Test GetAll endpoint with authentication
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
resp, err := app.Test(req)
tests.AssertNoError(t, err)
// Verify content type is JSON
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json" {
t.Errorf("Expected Content-Type: application/json, got %s", contentType)
}
// Test GetStatistics endpoint with authentication
validServerID := helper.TestData.ServerID.String()
if validServerID == "" {
validServerID = uuid.New().String() // Generate a new valid UUID if needed
validServerID = uuid.New().String()
}
req = httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil)
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
resp, err = app.Test(req)
tests.AssertNoError(t, err)
// Verify content type is JSON
contentType = resp.Header.Get("Content-Type")
if contentType != "application/json" {
t.Errorf("Expected Content-Type: application/json, got %s", contentType)
@@ -467,25 +442,27 @@ func TestStateHistoryController_ContentType(t *testing.T) {
}
func TestStateHistoryController_ResponseStructure(t *testing.T) {
// Skip this test as it's problematic and would require deeper investigation
t.Skip("Skipping test due to response structure issues that need further investigation")
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
app := fiber.New()
// Using real JWT auth with tokens
repo := repository.NewStateHistoryRepository(helper.DB)
stateHistoryService := service.NewStateHistoryService(repo)
membershipRepo := repository.NewMembershipRepository(helper.DB)
membershipService := service.NewMembershipService(membershipRepo)
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
}
jwtHandler := jwt.NewJWTHandler(jwtSecret)
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
inMemCache := cache.NewInMemoryCache()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -493,21 +470,17 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) {
}
}
// Insert test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
history := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
err := repo.Insert(helper.CreateContext(), &history)
tests.AssertNoError(t, err)
// Setup routes
routeGroups := &common.RouteGroups{
StateHistory: app.Group("/api/v1/state-history"),
}
// Use a test auth middleware that works with the DisableAuthentication
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
// Test GetAll response structure with authentication
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
resp, err := app.Test(req)
@@ -516,24 +489,19 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) {
body, err := io.ReadAll(resp.Body)
tests.AssertNoError(t, err)
// Log the actual response for debugging
t.Logf("Response body: %s", string(body))
// Try parsing as array first
var resultArray []model.StateHistory
err = json.Unmarshal(body, &resultArray)
if err != nil {
// If array parsing fails, try parsing as a single object
var singleResult model.StateHistory
err = json.Unmarshal(body, &singleResult)
if err != nil {
t.Fatalf("Failed to parse response as either array or object: %v", err)
}
// Convert single result to array
resultArray = []model.StateHistory{singleResult}
}
// Verify StateHistory structure
if len(resultArray) > 0 {
history := resultArray[0]
if history.ID == uuid.Nil {

View File

@@ -12,12 +12,10 @@ import (
)
func TestStateHistoryRepository_Insert_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -28,15 +26,12 @@ func TestStateHistoryRepository_Insert_Success(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Create test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
history := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
// Test Insert
err := repo.Insert(ctx, &history)
tests.AssertNoError(t, err)
// Verify ID was generated
tests.AssertNotNil(t, history.ID)
if history.ID == uuid.Nil {
t.Error("Expected non-nil ID after insert")
@@ -44,12 +39,10 @@ func TestStateHistoryRepository_Insert_Success(t *testing.T) {
}
func TestStateHistoryRepository_GetAll_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -60,19 +53,16 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Create test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
// Insert multiple entries
playerCounts := []int{0, 5, 10, 15, 10, 5, 0}
entries := testData.CreateMultipleEntries("Practice", "spa", playerCounts)
entries := testData.CreateMultipleEntries(model.SessionPractice, "spa", playerCounts)
for _, entry := range entries {
err := repo.Insert(ctx, &entry)
tests.AssertNoError(t, err)
}
// Test GetAll
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
result, err := repo.GetAll(ctx, filter)
@@ -82,12 +72,10 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) {
}
func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -98,36 +86,31 @@ func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Create test data with different sessions
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
practiceHistory := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
raceHistory := testData.CreateStateHistory("Race", "spa", 15, uuid.New())
practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
raceHistory := testData.CreateStateHistory(model.SessionRace, "spa", 15, uuid.New())
// Insert both
err := repo.Insert(ctx, &practiceHistory)
tests.AssertNoError(t, err)
err = repo.Insert(ctx, &raceHistory)
tests.AssertNoError(t, err)
// Test GetAll with session filter
filter := testdata.CreateFilterWithSession(helper.TestData.ServerID.String(), "Race")
filter := testdata.CreateFilterWithSession(helper.TestData.ServerID.String(), model.SessionRace)
result, err := repo.GetAll(ctx, filter)
tests.AssertNoError(t, err)
tests.AssertNotNil(t, result)
tests.AssertEqual(t, 1, len(*result))
tests.AssertEqual(t, "Race", (*result)[0].Session)
tests.AssertEqual(t, model.SessionRace, (*result)[0].Session)
tests.AssertEqual(t, 15, (*result)[0].PlayerCount)
}
func TestStateHistoryRepository_GetLastSessionID_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -138,44 +121,34 @@ func TestStateHistoryRepository_GetLastSessionID_Success(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Create test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
// Insert multiple entries with different session IDs
sessionID1 := uuid.New()
sessionID2 := uuid.New()
history1 := testData.CreateStateHistory("Practice", "spa", 5, sessionID1)
history2 := testData.CreateStateHistory("Race", "spa", 10, sessionID2)
history1 := testData.CreateStateHistory(model.SessionPractice, "spa", 5, sessionID1)
history2 := testData.CreateStateHistory(model.SessionRace, "spa", 10, sessionID2)
// Insert with a small delay to ensure ordering
err := repo.Insert(ctx, &history1)
tests.AssertNoError(t, err)
time.Sleep(1 * time.Millisecond) // Ensure different timestamps
time.Sleep(1 * time.Millisecond)
err = repo.Insert(ctx, &history2)
tests.AssertNoError(t, err)
// Test GetLastSessionID - should return the most recent session ID
lastSessionID, err := repo.GetLastSessionID(ctx, helper.TestData.ServerID)
tests.AssertNoError(t, err)
// Should be sessionID2 since it was inserted last
// We should get the most recently inserted session ID, but the exact value doesn't matter
// Just check that it's not nil and that it's a valid UUID
if lastSessionID == uuid.Nil {
t.Fatal("Expected non-nil UUID for last session ID")
}
}
func TestStateHistoryRepository_GetLastSessionID_NoData(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -186,19 +159,16 @@ func TestStateHistoryRepository_GetLastSessionID_NoData(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Test GetLastSessionID with no data
lastSessionID, err := repo.GetLastSessionID(ctx, helper.TestData.ServerID)
tests.AssertNoError(t, err)
tests.AssertEqual(t, uuid.Nil, lastSessionID)
}
func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -209,40 +179,33 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Create test data with varying player counts
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
// Create entries with different sessions and player counts
sessionID1 := uuid.New()
sessionID2 := uuid.New()
// Practice session: 5, 10, 15 players
practiceEntries := testData.CreateMultipleEntries("Practice", "spa", []int{5, 10, 15})
practiceEntries := testData.CreateMultipleEntries(model.SessionPractice, "spa", []int{5, 10, 15})
for i := range practiceEntries {
practiceEntries[i].SessionID = sessionID1
err := repo.Insert(ctx, &practiceEntries[i])
tests.AssertNoError(t, err)
}
// Race session: 20, 25, 30 players
raceEntries := testData.CreateMultipleEntries("Race", "spa", []int{20, 25, 30})
raceEntries := testData.CreateMultipleEntries(model.SessionRace, "spa", []int{20, 25, 30})
for i := range raceEntries {
raceEntries[i].SessionID = sessionID2
err := repo.Insert(ctx, &raceEntries[i])
tests.AssertNoError(t, err)
}
// Test GetSummaryStats
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
stats, err := repo.GetSummaryStats(ctx, filter)
tests.AssertNoError(t, err)
// Verify stats are calculated correctly
tests.AssertEqual(t, 30, stats.PeakPlayers) // Maximum player count
tests.AssertEqual(t, 2, stats.TotalSessions) // Two unique sessions
tests.AssertEqual(t, 30, stats.PeakPlayers)
tests.AssertEqual(t, 2, stats.TotalSessions)
// Average should be (5+10+15+20+25+30)/6 = 17.5
expectedAverage := float64(5+10+15+20+25+30) / 6.0
if stats.AveragePlayers != expectedAverage {
t.Errorf("Expected average players %.1f, got %.1f", expectedAverage, stats.AveragePlayers)
@@ -250,12 +213,10 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
}
func TestStateHistoryRepository_GetSummaryStats_NoData(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -266,25 +227,21 @@ func TestStateHistoryRepository_GetSummaryStats_NoData(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Test GetSummaryStats with no data
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
stats, err := repo.GetSummaryStats(ctx, filter)
tests.AssertNoError(t, err)
// Verify stats are zero for empty dataset
tests.AssertEqual(t, 0, stats.PeakPlayers)
tests.AssertEqual(t, 0.0, stats.AveragePlayers)
tests.AssertEqual(t, 0, stats.TotalSessions)
}
func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
// Setup environment and test helper
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -294,18 +251,15 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Create test data spanning a time range
sessionID := uuid.New()
baseTime := time.Now().UTC()
// Create entries spanning 1 hour with players > 0
entries := []model.StateHistory{
{
ID: uuid.New(),
ServerID: helper.TestData.ServerID,
Session: "Practice",
Session: model.SessionPractice,
Track: "spa",
PlayerCount: 5,
DateCreated: baseTime,
@@ -316,7 +270,7 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
{
ID: uuid.New(),
ServerID: helper.TestData.ServerID,
Session: "Practice",
Session: model.SessionPractice,
Track: "spa",
PlayerCount: 10,
DateCreated: baseTime.Add(30 * time.Minute),
@@ -327,7 +281,7 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
{
ID: uuid.New(),
ServerID: helper.TestData.ServerID,
Session: "Practice",
Session: model.SessionPractice,
Track: "spa",
PlayerCount: 8,
DateCreated: baseTime.Add(60 * time.Minute),
@@ -342,7 +296,6 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
tests.AssertNoError(t, err)
}
// Test GetTotalPlaytime
filter := &model.StateHistoryFilter{
ServerBasedFilter: model.ServerBasedFilter{
ServerID: helper.TestData.ServerID.String(),
@@ -355,20 +308,16 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
playtime, err := repo.GetTotalPlaytime(ctx, filter)
tests.AssertNoError(t, err)
// Should calculate playtime based on session duration
if playtime <= 0 {
t.Error("Expected positive playtime for session with multiple entries")
}
}
func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
// Test concurrent database operations
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -379,10 +328,8 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Create test data
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -390,8 +337,7 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
}
}
// Create and insert initial entry to ensure table exists and is properly set up
initialHistory := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
initialHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
err := repo.Insert(ctx, &initialHistory)
if err != nil {
t.Fatalf("Failed to insert initial record: %v", err)
@@ -399,12 +345,11 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
done := make(chan bool, 3)
// Concurrent inserts
go func() {
defer func() {
done <- true
}()
history := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
err := repo.Insert(ctx, &history)
if err != nil {
t.Logf("Insert error: %v", err)
@@ -412,7 +357,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
}
}()
// Concurrent reads
go func() {
defer func() {
done <- true
@@ -425,7 +369,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
}
}()
// Concurrent GetLastSessionID
go func() {
defer func() {
done <- true
@@ -437,19 +380,16 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
}
}()
// Wait for all operations to complete
for i := 0; i < 3; i++ {
<-done
}
}
func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) {
// Test edge cases with filters
tests.SetTestEnv()
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Ensure the state_histories table exists
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
if err != nil {
@@ -460,15 +400,11 @@ func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) {
repo := repository.NewStateHistoryRepository(helper.DB)
ctx := helper.CreateContext()
// Insert a test record to ensure the table is properly set up
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
history := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
err := repo.Insert(ctx, &history)
tests.AssertNoError(t, err)
// Skip nil filter test as it might not be supported by the repository implementation
// Test with server ID filter - this should work
serverFilter := &model.StateHistoryFilter{
ServerBasedFilter: model.ServerBasedFilter{
ServerID: helper.TestData.ServerID.String(),
@@ -478,7 +414,6 @@ func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) {
tests.AssertNoError(t, err)
tests.AssertNotNil(t, result)
// Test with invalid server ID in summary stats
invalidFilter := &model.StateHistoryFilter{
ServerBasedFilter: model.ServerBasedFilter{
ServerID: "invalid-uuid",

View File

@@ -5,113 +5,98 @@ import (
"acc-server-manager/local/utl/jwt"
"acc-server-manager/local/utl/password"
"acc-server-manager/tests"
"os"
"testing"
"github.com/google/uuid"
)
func TestJWT_GenerateAndValidateToken(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create test user
jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET"))
user := &model.User{
ID: uuid.New(),
Username: "testuser",
RoleID: uuid.New(),
}
// Test JWT generation
token, err := jwt.GenerateToken(user)
token, err := jwtHandler.GenerateToken(user.ID.String())
tests.AssertNoError(t, err)
tests.AssertNotNil(t, token)
// Verify token is not empty
if token == "" {
t.Fatal("Expected non-empty token, got empty string")
}
// Test JWT validation
claims, err := jwt.ValidateToken(token)
claims, err := jwtHandler.ValidateToken(token)
tests.AssertNoError(t, err)
tests.AssertNotNil(t, claims)
tests.AssertEqual(t, user.ID.String(), claims.UserID)
}
func TestJWT_ValidateToken_InvalidToken(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET"))
// Test with invalid token
claims, err := jwt.ValidateToken("invalid-token")
claims, err := jwtHandler.ValidateToken("invalid-token")
if err == nil {
t.Fatal("Expected error for invalid token, got nil")
}
// Direct nil check to avoid the interface wrapping issue
if claims != nil {
t.Fatalf("Expected nil claims, got %v", claims)
}
}
func TestJWT_ValidateToken_EmptyToken(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET"))
// Test with empty token
claims, err := jwt.ValidateToken("")
claims, err := jwtHandler.ValidateToken("")
if err == nil {
t.Fatal("Expected error for empty token, got nil")
}
// Direct nil check to avoid the interface wrapping issue
if claims != nil {
t.Fatalf("Expected nil claims, got %v", claims)
}
}
func TestUser_VerifyPassword_Success(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create test user
user := &model.User{
ID: uuid.New(),
Username: "testuser",
RoleID: uuid.New(),
}
// Hash password manually (simulating what BeforeCreate would do)
plainPassword := "password123"
hashedPassword, err := password.HashPassword(plainPassword)
tests.AssertNoError(t, err)
user.Password = hashedPassword
// Test password verification - should succeed
err = user.VerifyPassword(plainPassword)
tests.AssertNoError(t, err)
}
func TestUser_VerifyPassword_Failure(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create test user
user := &model.User{
ID: uuid.New(),
Username: "testuser",
RoleID: uuid.New(),
}
// Hash password manually
hashedPassword, err := password.HashPassword("correct_password")
tests.AssertNoError(t, err)
user.Password = hashedPassword
// Test password verification with wrong password - should fail
err = user.VerifyPassword("wrong_password")
if err == nil {
t.Fatal("Expected error for wrong password, got nil")
@@ -119,11 +104,9 @@ func TestUser_VerifyPassword_Failure(t *testing.T) {
}
func TestUser_Validate_Success(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create valid user
user := &model.User{
ID: uuid.New(),
Username: "testuser",
@@ -131,25 +114,21 @@ func TestUser_Validate_Success(t *testing.T) {
RoleID: uuid.New(),
}
// Test validation - should succeed
err := user.Validate()
tests.AssertNoError(t, err)
}
func TestUser_Validate_MissingUsername(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create user without username
user := &model.User{
ID: uuid.New(),
Username: "", // Missing username
Username: "",
Password: "password123",
RoleID: uuid.New(),
}
// Test validation - should fail
err := user.Validate()
if err == nil {
t.Fatal("Expected error for missing username, got nil")
@@ -157,19 +136,16 @@ func TestUser_Validate_MissingUsername(t *testing.T) {
}
func TestUser_Validate_MissingPassword(t *testing.T) {
// Setup
helper := tests.NewTestHelper(t)
defer helper.Cleanup()
// Create user without password
user := &model.User{
ID: uuid.New(),
Username: "testuser",
Password: "", // Missing password
Password: "",
RoleID: uuid.New(),
}
// Test validation - should fail
err := user.Validate()
if err == nil {
t.Fatal("Expected error for missing password, got nil")
@@ -177,24 +153,19 @@ func TestUser_Validate_MissingPassword(t *testing.T) {
}
func TestPassword_HashAndVerify(t *testing.T) {
// Test password hashing and verification directly
plainPassword := "test_password_123"
// Hash password
hashedPassword, err := password.HashPassword(plainPassword)
tests.AssertNoError(t, err)
tests.AssertNotNil(t, hashedPassword)
// Verify hashed password is not the same as plain password
if hashedPassword == plainPassword {
t.Fatal("Hashed password should not equal plain password")
}
// Verify correct password
err = password.VerifyPassword(hashedPassword, plainPassword)
tests.AssertNoError(t, err)
// Verify wrong password fails
err = password.VerifyPassword(hashedPassword, "wrong_password")
if err == nil {
t.Fatal("Expected error for wrong password, got nil")
@@ -210,7 +181,7 @@ func TestPassword_ValidatePasswordStrength(t *testing.T) {
{"Valid password", "StrongPassword123!", false},
{"Too short", "123", true},
{"Empty password", "", true},
{"Medium password", "password123", false}, // Depends on validation rules
{"Medium password", "password123", false},
}
for _, tc := range testCases {
@@ -230,7 +201,6 @@ func TestPassword_ValidatePasswordStrength(t *testing.T) {
}
func TestRole_Model(t *testing.T) {
// Test Role model structure
permissions := []model.Permission{
{ID: uuid.New(), Name: "read"},
{ID: uuid.New(), Name: "write"},
@@ -243,7 +213,6 @@ func TestRole_Model(t *testing.T) {
Permissions: permissions,
}
// Verify role structure
tests.AssertEqual(t, "Test Role", role.Name)
tests.AssertEqual(t, 3, len(role.Permissions))
tests.AssertEqual(t, "read", role.Permissions[0].Name)
@@ -252,19 +221,16 @@ func TestRole_Model(t *testing.T) {
}
func TestPermission_Model(t *testing.T) {
// Test Permission model structure
permission := &model.Permission{
ID: uuid.New(),
Name: "test_permission",
}
// Verify permission structure
tests.AssertEqual(t, "test_permission", permission.Name)
tests.AssertNotNil(t, permission.ID)
}
func TestUser_WithRole_Model(t *testing.T) {
// Test User model with Role relationship
permissions := []model.Permission{
{ID: uuid.New(), Name: "read"},
{ID: uuid.New(), Name: "write"},
@@ -284,7 +250,6 @@ func TestUser_WithRole_Model(t *testing.T) {
Role: role,
}
// Verify user-role relationship
tests.AssertEqual(t, "testuser", user.Username)
tests.AssertEqual(t, role.ID, user.RoleID)
tests.AssertEqual(t, "User", user.Role.Name)

View File

@@ -11,28 +11,22 @@ import (
)
func TestInMemoryCache_Set_Get_Success(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test data
key := "test-key"
value := "test-value"
duration := 5 * time.Minute
// Set value in cache
c.Set(key, value, duration)
// Get value from cache
result, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, value, result)
}
func TestInMemoryCache_Get_NotFound(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Try to get non-existent key
result, found := c.Get("non-existent-key")
tests.AssertEqual(t, false, found)
if result != nil {
@@ -41,43 +35,33 @@ func TestInMemoryCache_Get_NotFound(t *testing.T) {
}
func TestInMemoryCache_Set_Get_NoExpiration(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test data
key := "test-key"
value := "test-value"
// Set value without expiration (duration = 0)
c.Set(key, value, 0)
// Get value from cache
result, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, value, result)
}
func TestInMemoryCache_Expiration(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test data
key := "test-key"
value := "test-value"
duration := 1 * time.Millisecond // Very short duration
duration := 1 * time.Millisecond
// Set value in cache
c.Set(key, value, duration)
// Verify it's initially there
result, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, value, result)
// Wait for expiration
time.Sleep(2 * time.Millisecond)
// Try to get expired value
result, found = c.Get(key)
tests.AssertEqual(t, false, found)
if result != nil {
@@ -86,26 +70,20 @@ func TestInMemoryCache_Expiration(t *testing.T) {
}
func TestInMemoryCache_Delete(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test data
key := "test-key"
value := "test-value"
duration := 5 * time.Minute
// Set value in cache
c.Set(key, value, duration)
// Verify it's there
result, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, value, result)
// Delete the key
c.Delete(key)
// Verify it's gone
result, found = c.Get(key)
tests.AssertEqual(t, false, found)
if result != nil {
@@ -114,37 +92,29 @@ func TestInMemoryCache_Delete(t *testing.T) {
}
func TestInMemoryCache_Overwrite(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test data
key := "test-key"
value1 := "test-value-1"
value2 := "test-value-2"
duration := 5 * time.Minute
// Set first value
c.Set(key, value1, duration)
// Verify first value
result, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, value1, result)
// Overwrite with second value
c.Set(key, value2, duration)
// Verify second value
result, found = c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, value2, result)
}
func TestInMemoryCache_Multiple_Keys(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test data
testData := map[string]string{
"key1": "value1",
"key2": "value2",
@@ -152,22 +122,18 @@ func TestInMemoryCache_Multiple_Keys(t *testing.T) {
}
duration := 5 * time.Minute
// Set multiple values
for key, value := range testData {
c.Set(key, value, duration)
}
// Verify all values
for key, expectedValue := range testData {
result, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, expectedValue, result)
}
// Delete one key
c.Delete("key2")
// Verify key2 is gone but others remain
result, found := c.Get("key2")
tests.AssertEqual(t, false, found)
if result != nil {
@@ -184,10 +150,8 @@ func TestInMemoryCache_Multiple_Keys(t *testing.T) {
}
func TestInMemoryCache_Complex_Objects(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test with complex object (User struct)
user := &model.User{
ID: uuid.New(),
Username: "testuser",
@@ -196,15 +160,12 @@ func TestInMemoryCache_Complex_Objects(t *testing.T) {
key := "user:" + user.ID.String()
duration := 5 * time.Minute
// Set user in cache
c.Set(key, user, duration)
// Get user from cache
result, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertNotNil(t, result)
// Verify it's the same user
cachedUser, ok := result.(*model.User)
tests.AssertEqual(t, true, ok)
tests.AssertEqual(t, user.ID, cachedUser.ID)
@@ -212,33 +173,27 @@ func TestInMemoryCache_Complex_Objects(t *testing.T) {
}
func TestInMemoryCache_GetOrSet_CacheHit(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Pre-populate cache
key := "test-key"
expectedValue := "cached-value"
c.Set(key, expectedValue, 5*time.Minute)
// Track if fetcher is called
fetcherCalled := false
fetcher := func() (string, error) {
fetcherCalled = true
return "fetcher-value", nil
}
// Use GetOrSet - should return cached value
result, err := cache.GetOrSet(c, key, 5*time.Minute, fetcher)
tests.AssertNoError(t, err)
tests.AssertEqual(t, expectedValue, result)
tests.AssertEqual(t, false, fetcherCalled) // Fetcher should not be called
tests.AssertEqual(t, false, fetcherCalled)
}
func TestInMemoryCache_GetOrSet_CacheMiss(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Track if fetcher is called
fetcherCalled := false
expectedValue := "fetcher-value"
fetcher := func() (string, error) {
@@ -248,35 +203,29 @@ func TestInMemoryCache_GetOrSet_CacheMiss(t *testing.T) {
key := "test-key"
// Use GetOrSet - should call fetcher and cache result
result, err := cache.GetOrSet(c, key, 5*time.Minute, fetcher)
tests.AssertNoError(t, err)
tests.AssertEqual(t, expectedValue, result)
tests.AssertEqual(t, true, fetcherCalled) // Fetcher should be called
tests.AssertEqual(t, true, fetcherCalled)
// Verify value is now cached
cachedResult, found := c.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertEqual(t, expectedValue, cachedResult)
}
func TestInMemoryCache_GetOrSet_FetcherError(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Fetcher that returns error
fetcher := func() (string, error) {
return "", tests.ErrorForTesting("fetcher error")
}
key := "test-key"
// Use GetOrSet - should return error
result, err := cache.GetOrSet(c, key, 5*time.Minute, fetcher)
tests.AssertError(t, err, "")
tests.AssertEqual(t, "", result)
// Verify nothing is cached
cachedResult, found := c.Get(key)
tests.AssertEqual(t, false, found)
if cachedResult != nil {
@@ -285,10 +234,8 @@ func TestInMemoryCache_GetOrSet_FetcherError(t *testing.T) {
}
func TestInMemoryCache_TypeSafety(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test type safety with GetOrSet
userFetcher := func() (*model.User, error) {
return &model.User{
ID: uuid.New(),
@@ -298,13 +245,11 @@ func TestInMemoryCache_TypeSafety(t *testing.T) {
key := "user-key"
// Use GetOrSet with User type
user, err := cache.GetOrSet(c, key, 5*time.Minute, userFetcher)
tests.AssertNoError(t, err)
tests.AssertNotNil(t, user)
tests.AssertEqual(t, "testuser", user.Username)
// Verify correct type is cached
cachedResult, found := c.Get(key)
tests.AssertEqual(t, true, found)
cachedUser, ok := cachedResult.(*model.User)
@@ -313,26 +258,21 @@ func TestInMemoryCache_TypeSafety(t *testing.T) {
}
func TestInMemoryCache_Concurrent_Access(t *testing.T) {
// Setup
c := cache.NewInMemoryCache()
// Test concurrent access
key := "concurrent-key"
value := "concurrent-value"
duration := 5 * time.Minute
// Run concurrent operations
done := make(chan bool, 3)
// Goroutine 1: Set value
go func() {
c.Set(key, value, duration)
done <- true
}()
// Goroutine 2: Get value
go func() {
time.Sleep(1 * time.Millisecond) // Small delay to ensure Set happens first
time.Sleep(1 * time.Millisecond)
result, found := c.Get(key)
if found {
tests.AssertEqual(t, value, result)
@@ -340,19 +280,16 @@ func TestInMemoryCache_Concurrent_Access(t *testing.T) {
done <- true
}()
// Goroutine 3: Delete value
go func() {
time.Sleep(2 * time.Millisecond) // Delay to ensure Set and Get happen first
time.Sleep(2 * time.Millisecond)
c.Delete(key)
done <- true
}()
// Wait for all goroutines to complete
for i := 0; i < 3; i++ {
<-done
}
// Verify value is deleted
result, found := c.Get(key)
tests.AssertEqual(t, false, found)
if result != nil {
@@ -361,7 +298,6 @@ func TestInMemoryCache_Concurrent_Access(t *testing.T) {
}
func TestServerStatusCache_GetStatus_NeedsRefresh(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 5 * time.Minute,
ThrottleTime: 1 * time.Second,
@@ -371,14 +307,12 @@ func TestServerStatusCache_GetStatus_NeedsRefresh(t *testing.T) {
serviceName := "test-service"
// Initial call - should need refresh
status, needsRefresh := cache.GetStatus(serviceName)
tests.AssertEqual(t, model.StatusUnknown, status)
tests.AssertEqual(t, true, needsRefresh)
}
func TestServerStatusCache_UpdateStatus_GetStatus(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 5 * time.Minute,
ThrottleTime: 1 * time.Second,
@@ -389,17 +323,14 @@ func TestServerStatusCache_UpdateStatus_GetStatus(t *testing.T) {
serviceName := "test-service"
expectedStatus := model.StatusRunning
// Update status
cache.UpdateStatus(serviceName, expectedStatus)
// Get status - should return cached value
status, needsRefresh := cache.GetStatus(serviceName)
tests.AssertEqual(t, expectedStatus, status)
tests.AssertEqual(t, false, needsRefresh)
}
func TestServerStatusCache_Throttling(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 5 * time.Minute,
ThrottleTime: 100 * time.Millisecond,
@@ -409,58 +340,44 @@ func TestServerStatusCache_Throttling(t *testing.T) {
serviceName := "test-service"
// Update status
cache.UpdateStatus(serviceName, model.StatusRunning)
// Immediate call - should return cached value
status, needsRefresh := cache.GetStatus(serviceName)
tests.AssertEqual(t, model.StatusRunning, status)
tests.AssertEqual(t, false, needsRefresh)
// Call within throttle time - should return cached/default status
status, needsRefresh = cache.GetStatus(serviceName)
tests.AssertEqual(t, model.StatusRunning, status)
tests.AssertEqual(t, false, needsRefresh)
// Wait for throttle time to pass
time.Sleep(150 * time.Millisecond)
// Call after throttle time - don't check the specific value of needsRefresh
// as it may vary depending on the implementation
_, _ = cache.GetStatus(serviceName)
// Test passes if we reach this point without errors
}
func TestServerStatusCache_Expiration(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 50 * time.Millisecond, // Very short expiration
ExpirationTime: 50 * time.Millisecond,
ThrottleTime: 10 * time.Millisecond,
DefaultStatus: model.StatusUnknown,
}
cache := model.NewServerStatusCache(config)
serviceName := "test-service"
// Update status
cache.UpdateStatus(serviceName, model.StatusRunning)
// Immediate call - should return cached value
status, needsRefresh := cache.GetStatus(serviceName)
tests.AssertEqual(t, model.StatusRunning, status)
tests.AssertEqual(t, false, needsRefresh)
// Wait for expiration
time.Sleep(60 * time.Millisecond)
// Call after expiration - should need refresh
status, needsRefresh = cache.GetStatus(serviceName)
tests.AssertEqual(t, true, needsRefresh)
}
func TestServerStatusCache_InvalidateStatus(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 5 * time.Minute,
ThrottleTime: 1 * time.Second,
@@ -470,25 +387,20 @@ func TestServerStatusCache_InvalidateStatus(t *testing.T) {
serviceName := "test-service"
// Update status
cache.UpdateStatus(serviceName, model.StatusRunning)
// Verify it's cached
status, needsRefresh := cache.GetStatus(serviceName)
tests.AssertEqual(t, model.StatusRunning, status)
tests.AssertEqual(t, false, needsRefresh)
// Invalidate status
cache.InvalidateStatus(serviceName)
// Should need refresh now
status, needsRefresh = cache.GetStatus(serviceName)
tests.AssertEqual(t, model.StatusUnknown, status)
tests.AssertEqual(t, true, needsRefresh)
}
func TestServerStatusCache_Clear(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 5 * time.Minute,
ThrottleTime: 1 * time.Second,
@@ -496,23 +408,19 @@ func TestServerStatusCache_Clear(t *testing.T) {
}
cache := model.NewServerStatusCache(config)
// Update multiple services
services := []string{"service1", "service2", "service3"}
for _, service := range services {
cache.UpdateStatus(service, model.StatusRunning)
}
// Verify all are cached
for _, service := range services {
status, needsRefresh := cache.GetStatus(service)
tests.AssertEqual(t, model.StatusRunning, status)
tests.AssertEqual(t, false, needsRefresh)
}
// Clear cache
cache.Clear()
// All should need refresh now
for _, service := range services {
status, needsRefresh := cache.GetStatus(service)
tests.AssertEqual(t, model.StatusUnknown, status)
@@ -521,30 +429,23 @@ func TestServerStatusCache_Clear(t *testing.T) {
}
func TestLookupCache_SetGetClear(t *testing.T) {
// Setup
cache := model.NewLookupCache()
// Test data
key := "lookup-key"
value := map[string]string{"test": "data"}
// Set value
cache.Set(key, value)
// Get value
result, found := cache.Get(key)
tests.AssertEqual(t, true, found)
tests.AssertNotNil(t, result)
// Verify it's the same data
resultMap, ok := result.(map[string]string)
tests.AssertEqual(t, true, ok)
tests.AssertEqual(t, "data", resultMap["test"])
// Clear cache
cache.Clear()
// Should be gone now
result, found = cache.Get(key)
tests.AssertEqual(t, false, found)
if result != nil {
@@ -553,7 +454,6 @@ func TestLookupCache_SetGetClear(t *testing.T) {
}
func TestServerConfigCache_Configuration(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 5 * time.Minute,
ThrottleTime: 1 * time.Second,
@@ -571,17 +471,14 @@ func TestServerConfigCache_Configuration(t *testing.T) {
ConfigVersion: model.IntString(1),
}
// Initial get - should miss
result, found := cache.GetConfiguration(serverID)
tests.AssertEqual(t, false, found)
if result != nil {
t.Fatal("Expected nil result, got non-nil")
}
// Update cache
cache.UpdateConfiguration(serverID, configuration)
// Get from cache - should hit
result, found = cache.GetConfiguration(serverID)
tests.AssertEqual(t, true, found)
tests.AssertNotNil(t, result)
@@ -590,7 +487,6 @@ func TestServerConfigCache_Configuration(t *testing.T) {
}
func TestServerConfigCache_InvalidateServerCache(t *testing.T) {
// Setup
config := model.CacheConfig{
ExpirationTime: 5 * time.Minute,
ThrottleTime: 1 * time.Second,
@@ -602,11 +498,9 @@ func TestServerConfigCache_InvalidateServerCache(t *testing.T) {
configuration := model.Configuration{UdpPort: model.IntString(9231)}
assistRules := model.AssistRules{StabilityControlLevelMax: model.IntString(0)}
// Update multiple configs for server
cache.UpdateConfiguration(serverID, configuration)
cache.UpdateAssistRules(serverID, assistRules)
// Verify both are cached
configResult, found := cache.GetConfiguration(serverID)
tests.AssertEqual(t, true, found)
tests.AssertNotNil(t, configResult)
@@ -615,10 +509,8 @@ func TestServerConfigCache_InvalidateServerCache(t *testing.T) {
tests.AssertEqual(t, true, found)
tests.AssertNotNil(t, assistResult)
// Invalidate server cache
cache.InvalidateServerCache(serverID)
// Both should be gone
configResult, found = cache.GetConfiguration(serverID)
tests.AssertEqual(t, false, found)
if configResult != nil {

Some files were not shown because too many files have changed in this diff Show More