Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4004d83411 | ||
|
|
901dbe697e | ||
|
|
760412d7db | ||
|
|
4ab94de529 | ||
|
|
b3f89593fb | ||
|
|
2a863c51e9 | ||
|
|
a70d923a6a | ||
|
|
f660511b63 | ||
|
|
044af60699 | ||
|
|
384036bcdd | ||
|
|
ef300d233b | ||
|
|
edad65d6a9 | ||
|
|
486c972bba | ||
|
|
aab5d2ad61 | ||
|
|
1683d5c2f1 | ||
|
|
87d4af0bec | ||
|
|
35449a090d | ||
|
|
5324a41e05 | ||
|
|
ac61ba5223 | ||
|
|
56c51e5d02 | ||
|
|
1c57da9aba | ||
|
|
b2d88f1aa3 | ||
|
|
45d9681203 |
@@ -10,6 +10,7 @@ env:
|
||||
MIGRATE_BINARY: "acc-server-migration"
|
||||
DEPLOY_PATH: 'C:\acc-server-manager'
|
||||
SERVICE_NAME: "ACC Server Manager"
|
||||
HEALTH_URL: "http://localhost:4000/v1/system/health"
|
||||
|
||||
jobs:
|
||||
build:
|
||||
@@ -135,7 +136,7 @@ jobs:
|
||||
|
||||
while ($attempt -le $maxAttempts -and -not $success) {
|
||||
try {
|
||||
$response = Invoke-WebRequest -Uri "http://localhost:8080/health" -TimeoutSec 5
|
||||
$response = Invoke-WebRequest -Uri "${{ env.HEALTH_URL }}" -TimeoutSec 5
|
||||
if ($response.StatusCode -eq 200) {
|
||||
Write-Host "Health check passed!"
|
||||
$success = $true
|
||||
@@ -150,31 +151,3 @@ jobs:
|
||||
if (-not $success) {
|
||||
throw "Health check failed after $maxAttempts attempts"
|
||||
}
|
||||
|
||||
- name: Notify on success
|
||||
if: success()
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const { repo, owner } = context.repo;
|
||||
const release = context.ref.replace('refs/tags/', '');
|
||||
await github.rest.issues.createComment({
|
||||
owner,
|
||||
repo,
|
||||
issue_number: context.issue.number,
|
||||
body: `✅ Successfully deployed ${release} to production!`
|
||||
});
|
||||
|
||||
- name: Notify on failure
|
||||
if: failure()
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const { repo, owner } = context.repo;
|
||||
const release = context.ref.replace('refs/tags/', '');
|
||||
await github.rest.issues.createComment({
|
||||
owner,
|
||||
repo,
|
||||
issue_number: context.issue.number,
|
||||
body: `❌ Failed to deploy ${release} to production. Check the workflow logs for details.`
|
||||
});
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
|
||||
func main() {
|
||||
configs.Init()
|
||||
jwt.Init()
|
||||
// Initialize new logging system
|
||||
if err := logging.InitializeLogging(); err != nil {
|
||||
fmt.Printf("Failed to initialize logging system: %v\n", err)
|
||||
@@ -37,6 +36,8 @@ func main() {
|
||||
logging.InfoStartup("APPLICATION", "ACC Server Manager starting up")
|
||||
|
||||
di := dig.New()
|
||||
di.Provide(func() *jwt.JWTHandler { return jwt.NewJWTHandler(os.Getenv("JWT_SECRET")) })
|
||||
di.Provide(func() *jwt.OpenJWTHandler { return jwt.NewOpenJWTHandler(os.Getenv("JWT_SECRET_OPEN")) })
|
||||
cache.Start(di)
|
||||
db.Start(di)
|
||||
server.Start(di)
|
||||
|
||||
@@ -5,14 +5,14 @@
|
||||
// @description API for managing Assetto Corsa Competizione dedicated servers
|
||||
//
|
||||
// @contact.name ACC Server Manager Support
|
||||
// @contact.url https://github.com/yourusername/acc-server-manager
|
||||
// @contact.url https://github.com/FJurmanovic/acc-server-manager
|
||||
//
|
||||
// @license.name MIT
|
||||
// @license.url https://opensource.org/licenses/MIT
|
||||
//
|
||||
// @host localhost:3000
|
||||
// @BasePath /api/v1
|
||||
// @schemes http https
|
||||
// @host acc-api.jurmanovic.com
|
||||
// @BasePath /v1
|
||||
// @schemes https
|
||||
//
|
||||
// @securityDefinitions.apikey BearerAuth
|
||||
// @in header
|
||||
|
||||
93
cmd/steam-crypt/main.go
Normal file
93
cmd/steam-crypt/main.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/model"
|
||||
"acc-server-manager/local/utl/configs"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var (
|
||||
encrypt = flag.Bool("encrypt", false, "Encrypt a password")
|
||||
decrypt = flag.Bool("decrypt", false, "Decrypt a password")
|
||||
password = flag.String("password", "", "Password to encrypt/decrypt")
|
||||
help = flag.Bool("help", false, "Show help")
|
||||
)
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if *help || (!*encrypt && !*decrypt) {
|
||||
showHelp()
|
||||
return
|
||||
}
|
||||
|
||||
if *encrypt && *decrypt {
|
||||
fmt.Fprintf(os.Stderr, "Error: Cannot specify both -encrypt and -decrypt\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *password == "" {
|
||||
fmt.Fprintf(os.Stderr, "Error: Password is required\n")
|
||||
showHelp()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Initialize configs to load encryption key
|
||||
configs.Init()
|
||||
|
||||
if *encrypt {
|
||||
encrypted, err := model.EncryptPassword(*password)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error encrypting password: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println(encrypted)
|
||||
}
|
||||
|
||||
if *decrypt {
|
||||
decrypted, err := model.DecryptPassword(*password)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error decrypting password: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println(decrypted)
|
||||
}
|
||||
}
|
||||
|
||||
func showHelp() {
|
||||
fmt.Println("Steam Credentials Encryption/Decryption Utility")
|
||||
fmt.Println()
|
||||
fmt.Println("This utility encrypts and decrypts Steam credentials using the same")
|
||||
fmt.Println("AES-256-GCM encryption used by the ACC Server Manager application.")
|
||||
fmt.Println()
|
||||
fmt.Println("Usage:")
|
||||
fmt.Println(" steam-crypt -encrypt -password \"your_password\"")
|
||||
fmt.Println(" steam-crypt -decrypt -password \"encrypted_string\"")
|
||||
fmt.Println()
|
||||
fmt.Println("Options:")
|
||||
fmt.Println(" -encrypt Encrypt the provided password")
|
||||
fmt.Println(" -decrypt Decrypt the provided encrypted string")
|
||||
fmt.Println(" -password The password to encrypt or encrypted string to decrypt")
|
||||
fmt.Println(" -help Show this help message")
|
||||
fmt.Println()
|
||||
fmt.Println("Environment Variables Required:")
|
||||
fmt.Println(" ENCRYPTION_KEY - 32-byte encryption key (same as main application)")
|
||||
fmt.Println(" APP_SECRET - Application secret (required by configs)")
|
||||
fmt.Println(" APP_SECRET_CODE - Application secret code (required by configs)")
|
||||
fmt.Println(" ACCESS_KEY - Access key (required by configs)")
|
||||
fmt.Println()
|
||||
fmt.Println("Examples:")
|
||||
fmt.Println(" # Encrypt a password")
|
||||
fmt.Println(" steam-crypt -encrypt -password \"mysteampassword\"")
|
||||
fmt.Println()
|
||||
fmt.Println(" # Decrypt an encrypted password")
|
||||
fmt.Println(" steam-crypt -decrypt -password \"base64encryptedstring\"")
|
||||
fmt.Println()
|
||||
fmt.Println("Security Notes:")
|
||||
fmt.Println(" - The encryption key must be exactly 32 bytes for AES-256")
|
||||
fmt.Println(" - Uses AES-256-GCM for authenticated encryption")
|
||||
fmt.Println(" - Each encryption includes a unique nonce for security")
|
||||
fmt.Println(" - Passwords are validated for length and basic security")
|
||||
}
|
||||
243
docs/STEAM_2FA_IMPLEMENTATION.md
Normal file
243
docs/STEAM_2FA_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,243 @@
|
||||
# Steam 2FA Implementation Documentation
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the implementation of Steam Two-Factor Authentication (2FA) support for the ACC Server Manager. When SteamCMD requires 2FA confirmation during server installation or updates, the system now signals the frontend and waits for user confirmation before proceeding.
|
||||
|
||||
## Architecture
|
||||
|
||||
The 2FA implementation consists of several interconnected components:
|
||||
|
||||
### Backend Components
|
||||
|
||||
1. **Steam2FAManager** (`local/model/steam_2fa.go`)
|
||||
- Thread-safe management of 2FA requests
|
||||
- Request lifecycle tracking (pending → complete/error)
|
||||
- Channel-based waiting mechanism for synchronization
|
||||
|
||||
2. **InteractiveCommandExecutor** (`local/utl/command/interactive_executor.go`)
|
||||
- Monitors SteamCMD output for 2FA prompts
|
||||
- Creates 2FA requests when prompts are detected
|
||||
- Waits for user confirmation before proceeding
|
||||
|
||||
3. **Steam2FAController** (`local/controller/steam_2fa.go`)
|
||||
- REST API endpoints for 2FA management
|
||||
- Handles frontend requests to complete/cancel 2FA
|
||||
|
||||
4. **Updated SteamService** (`local/service/steam_service.go`)
|
||||
- Uses InteractiveCommandExecutor for SteamCMD operations
|
||||
- Passes server context to 2FA requests
|
||||
|
||||
### Frontend Components
|
||||
|
||||
1. **Steam2FA Store** (`src/stores/steam2fa.ts`)
|
||||
- Svelte store for managing 2FA state
|
||||
- Automatic polling for pending requests
|
||||
- API communication methods
|
||||
|
||||
2. **Steam2FANotification Component** (`src/components/Steam2FANotification.svelte`)
|
||||
- Modal UI for 2FA confirmation
|
||||
- Automatic display when requests are pending
|
||||
- User interaction handling
|
||||
|
||||
3. **Type Definitions** (`src/models/steam2fa.ts`)
|
||||
- TypeScript interfaces for 2FA data structures
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### GET /v1/steam2fa/pending
|
||||
Returns all pending 2FA requests.
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
[
|
||||
{
|
||||
"id": "uuid-string",
|
||||
"status": "pending",
|
||||
"message": "Steam Guard prompt message",
|
||||
"requestTime": "2024-01-01T12:00:00Z",
|
||||
"serverId": "server-uuid"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### GET /v1/steam2fa/{id}
|
||||
Returns a specific 2FA request by ID.
|
||||
|
||||
### POST /v1/steam2fa/{id}/complete
|
||||
Marks a 2FA request as completed, allowing SteamCMD to proceed.
|
||||
|
||||
### POST /v1/steam2fa/{id}/cancel
|
||||
Cancels a 2FA request, causing the SteamCMD operation to fail.
|
||||
|
||||
## Flow Diagram
|
||||
|
||||
```
|
||||
SteamCMD Operation
|
||||
↓
|
||||
InteractiveCommandExecutor monitors output
|
||||
↓
|
||||
2FA prompt detected
|
||||
↓
|
||||
Steam2FARequest created
|
||||
↓
|
||||
Frontend polls and detects request
|
||||
↓
|
||||
Modal appears for user
|
||||
↓
|
||||
User confirms in Steam Mobile App
|
||||
↓
|
||||
User clicks "I've Confirmed"
|
||||
↓
|
||||
API call to complete request
|
||||
↓
|
||||
SteamCMD operation continues
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Backend Configuration
|
||||
|
||||
The system uses existing configuration patterns. No additional environment variables are required.
|
||||
|
||||
### Frontend Configuration
|
||||
|
||||
The API base URL is automatically configured as `/v1` to match the backend prefix.
|
||||
|
||||
Polling interval is set to 5 seconds by default and can be modified in `steam2fa.ts`:
|
||||
|
||||
```typescript
|
||||
const POLLING_INTERVAL = 5000; // milliseconds
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
1. **Authentication Required**: All 2FA endpoints require user authentication
|
||||
2. **Permission-Based Access**: Uses existing `ServerView` and `ServerUpdate` permissions
|
||||
3. **Request Cleanup**: Automatic cleanup of old requests (30 minutes) prevents memory leaks
|
||||
4. **No Sensitive Data**: No Steam credentials are exposed through the 2FA system
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Backend Error Handling
|
||||
- Timeouts after 5 minutes if no user response
|
||||
- Proper error propagation to calling services
|
||||
- Comprehensive logging for debugging
|
||||
|
||||
### Frontend Error Handling
|
||||
- Network error handling with user feedback
|
||||
- Automatic retry mechanisms
|
||||
- Graceful degradation when API is unavailable
|
||||
|
||||
## Usage Instructions
|
||||
|
||||
### For Developers
|
||||
|
||||
1. **Adding New 2FA Prompts**: Extend the `is2FAPrompt` function in `interactive_executor.go`
|
||||
2. **Customizing Timeouts**: Modify the timeout duration in `handle2FAPrompt`
|
||||
3. **UI Customization**: Modify the `Steam2FANotification.svelte` component
|
||||
|
||||
### For Users
|
||||
|
||||
1. When creating or updating a server, watch for the 2FA notification
|
||||
2. Check your Steam Mobile App when prompted
|
||||
3. Confirm the login request in the Steam app
|
||||
4. Click "I've Confirmed" in the web interface
|
||||
5. The server operation will continue automatically
|
||||
|
||||
## Monitoring and Debugging
|
||||
|
||||
### Backend Logs
|
||||
The system logs important events:
|
||||
- 2FA prompt detection
|
||||
- Request creation and completion
|
||||
- Timeout events
|
||||
- Error conditions
|
||||
|
||||
Search for log entries containing:
|
||||
- `2FA prompt detected`
|
||||
- `Created 2FA request`
|
||||
- `2FA completed successfully`
|
||||
- `2FA completion failed`
|
||||
|
||||
### Frontend Debugging
|
||||
The Steam2FA store provides debugging information:
|
||||
- `$steam2fa.error` - Current error state
|
||||
- `$steam2fa.isLoading` - Loading state
|
||||
- `$steam2fa.lastChecked` - Last polling timestamp
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
1. **Polling Frequency**: 5-second polling provides good responsiveness without excessive load
|
||||
2. **Request Cleanup**: Automatic cleanup prevents memory accumulation
|
||||
3. **Efficient UI Updates**: Reactive Svelte stores minimize unnecessary re-renders
|
||||
|
||||
## Limitations
|
||||
|
||||
1. **Single User Sessions**: Currently designed for single-user scenarios
|
||||
2. **Steam Mobile App Required**: Users must have Steam Mobile App installed
|
||||
3. **Manual Confirmation**: No automatic 2FA code input support
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
1. **WebSocket Support**: Real-time communication instead of polling
|
||||
2. **Multiple User Support**: Handle multiple simultaneous 2FA requests
|
||||
3. **Enhanced Prompt Detection**: More sophisticated Steam output parsing
|
||||
4. **Notification System**: Browser notifications for 2FA requests
|
||||
|
||||
## Testing
|
||||
|
||||
### Manual Testing
|
||||
1. Create a new server to trigger SteamCMD
|
||||
2. Ensure Steam account has 2FA enabled
|
||||
3. Verify modal appears when 2FA is required
|
||||
4. Test both "confirm" and "cancel" workflows
|
||||
|
||||
### Automated Testing
|
||||
The system includes comprehensive error handling but manual testing is recommended for 2FA workflows due to the interactive nature.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Modal doesn't appear**
|
||||
- Check browser console for errors
|
||||
- Verify API connectivity
|
||||
- Ensure user has proper permissions
|
||||
|
||||
2. **SteamCMD hangs**
|
||||
- Check if 2FA request was created (backend logs)
|
||||
- Verify Steam Mobile App connectivity
|
||||
- Check for timeout errors
|
||||
|
||||
3. **API errors**
|
||||
- Verify user authentication
|
||||
- Check server permissions
|
||||
- Review backend error logs
|
||||
|
||||
### Debug Commands
|
||||
|
||||
```bash
|
||||
# Check backend logs for 2FA events
|
||||
grep -i "2fa" logs/app.log
|
||||
|
||||
# Monitor API requests
|
||||
tail -f logs/app.log | grep "steam2fa"
|
||||
```
|
||||
|
||||
## Version History
|
||||
|
||||
- **v1.0.0**: Initial implementation with polling-based frontend and REST API
|
||||
- Added comprehensive error handling and logging
|
||||
- Implemented automatic request cleanup
|
||||
- Added responsive UI components
|
||||
|
||||
## Contributing
|
||||
|
||||
When contributing to the 2FA system:
|
||||
|
||||
1. Follow existing error handling patterns
|
||||
2. Add comprehensive logging for new features
|
||||
3. Update this documentation for any API changes
|
||||
4. Test with actual Steam 2FA scenarios
|
||||
5. Consider security implications of any changes
|
||||
3
go.mod
3
go.mod
@@ -21,10 +21,12 @@ require (
|
||||
require (
|
||||
github.com/KyleBanks/depth v1.2.1 // indirect
|
||||
github.com/andybalholm/brotli v1.1.0 // indirect
|
||||
github.com/fasthttp/websocket v1.5.3 // indirect
|
||||
github.com/go-openapi/jsonpointer v0.21.0 // indirect
|
||||
github.com/go-openapi/jsonreference v0.21.0 // indirect
|
||||
github.com/go-openapi/spec v0.21.0 // indirect
|
||||
github.com/go-openapi/swag v0.23.0 // indirect
|
||||
github.com/gofiber/websocket/v2 v2.2.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
@@ -35,6 +37,7 @@ require (
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect
|
||||
github.com/swaggo/files/v2 v2.0.0 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasthttp v1.51.0 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -4,6 +4,8 @@ github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1
|
||||
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/fasthttp/websocket v1.5.3 h1:TPpQuLwJYfd4LJPXvHDYPMFWbLjsT91n3GpWtCQtdek=
|
||||
github.com/fasthttp/websocket v1.5.3/go.mod h1:46gg/UBmTU1kUaTcwQXpUxtRwG2PvIZYeA8oL6vF3Fs=
|
||||
github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ=
|
||||
github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY=
|
||||
github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ=
|
||||
@@ -16,6 +18,8 @@ github.com/gofiber/fiber/v2 v2.52.8 h1:xl4jJQ0BV5EJTA2aWiKw/VddRpHrKeZLF0QPUxqn0
|
||||
github.com/gofiber/fiber/v2 v2.52.8/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
|
||||
github.com/gofiber/swagger v1.1.0 h1:ff3rg1fB+Rp5JN/N8jfxTiZtMKe/9tB9QDc79fPiJKQ=
|
||||
github.com/gofiber/swagger v1.1.0/go.mod h1:pRZL0Np35sd+lTODTE5The0G+TMHfNY+oC4hM2/i5m8=
|
||||
github.com/gofiber/websocket/v2 v2.2.1 h1:C9cjxvloojayOp9AovmpQrk8VqvVnT8Oao3+IUygH7w=
|
||||
github.com/gofiber/websocket/v2 v2.2.1/go.mod h1:Ao/+nyNnX5u/hIFPuHl28a+NIkrqK7PRimyKaj4JxVU=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
@@ -53,6 +57,8 @@ github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
|
||||
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk=
|
||||
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/swaggo/files/v2 v2.0.0 h1:hmAt8Dkynw7Ssz46F6pn8ok6YmGZqHSVLZ+HQM7i0kw=
|
||||
|
||||
@@ -31,6 +31,7 @@ func Init(di *dig.Container, app *fiber.App) {
|
||||
StateHistory: serverIdGroup.Group("/state-history"),
|
||||
Membership: groups.Group("/membership"),
|
||||
System: groups.Group("/system"),
|
||||
WebSocket: groups.Group("/ws"),
|
||||
}
|
||||
|
||||
accessKeyMiddleware := middleware.NewAccessKeyMiddleware()
|
||||
|
||||
@@ -58,7 +58,7 @@ func NewConfigController(as *service.ConfigService, routeGroups *common.RouteGro
|
||||
// @Failure 404 {object} error_handler.ErrorResponse "Server or config file not found"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/server/{id}/config/{file} [put]
|
||||
// @Router /server/{id}/config/{file} [put]
|
||||
func (ac *ConfigController) UpdateConfig(c *fiber.Ctx) error {
|
||||
restart := c.QueryBool("restart")
|
||||
serverID := c.Params("id")
|
||||
@@ -106,7 +106,7 @@ func (ac *ConfigController) UpdateConfig(c *fiber.Ctx) error {
|
||||
// @Failure 404 {object} error_handler.ErrorResponse "Server or config file not found"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/server/{id}/config/{file} [get]
|
||||
// @Router /server/{id}/config/{file} [get]
|
||||
func (ac *ConfigController) GetConfig(c *fiber.Ctx) error {
|
||||
Model, err := ac.service.GetConfig(c)
|
||||
if err != nil {
|
||||
@@ -130,7 +130,7 @@ func (ac *ConfigController) GetConfig(c *fiber.Ctx) error {
|
||||
// @Failure 404 {object} error_handler.ErrorResponse "Server not found"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/server/{id}/config [get]
|
||||
// @Router /server/{id}/config [get]
|
||||
func (ac *ConfigController) GetConfigs(c *fiber.Ctx) error {
|
||||
Model, err := ac.service.GetConfigs(c)
|
||||
if err != nil {
|
||||
|
||||
@@ -20,7 +20,12 @@ func InitializeControllers(c *dig.Container) {
|
||||
logging.Panic("unable to initialize auth middleware")
|
||||
}
|
||||
|
||||
err := c.Invoke(NewServiceControlController)
|
||||
err := c.Invoke(NewSystemController)
|
||||
if err != nil {
|
||||
logging.Panic("unable to initialize system controller")
|
||||
}
|
||||
|
||||
err = c.Invoke(NewServiceControlController)
|
||||
if err != nil {
|
||||
logging.Panic("unable to initialize service control controller")
|
||||
}
|
||||
@@ -49,4 +54,9 @@ func InitializeControllers(c *dig.Container) {
|
||||
if err != nil {
|
||||
logging.Panic("unable to initialize membership controller")
|
||||
}
|
||||
|
||||
err = c.Invoke(NewWebSocketController)
|
||||
if err != nil {
|
||||
logging.Panic("unable to initialize websocket controller")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func NewLookupController(as *service.LookupService, routeGroups *common.RouteGro
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/lookup/tracks [get]
|
||||
// @Router /lookup/tracks [get]
|
||||
func (ac *LookupController) GetTracks(c *fiber.Ctx) error {
|
||||
result, err := ac.service.GetTracks(c)
|
||||
if err != nil {
|
||||
@@ -66,7 +66,7 @@ func (ac *LookupController) GetTracks(c *fiber.Ctx) error {
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/lookup/car-models [get]
|
||||
// @Router /lookup/car-models [get]
|
||||
func (ac *LookupController) GetCarModels(c *fiber.Ctx) error {
|
||||
result, err := ac.service.GetCarModels(c)
|
||||
if err != nil {
|
||||
@@ -86,7 +86,7 @@ func (ac *LookupController) GetCarModels(c *fiber.Ctx) error {
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/lookup/driver-categories [get]
|
||||
// @Router /lookup/driver-categories [get]
|
||||
func (ac *LookupController) GetDriverCategories(c *fiber.Ctx) error {
|
||||
result, err := ac.service.GetDriverCategories(c)
|
||||
if err != nil {
|
||||
@@ -106,7 +106,7 @@ func (ac *LookupController) GetDriverCategories(c *fiber.Ctx) error {
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/lookup/cup-categories [get]
|
||||
// @Router /lookup/cup-categories [get]
|
||||
func (ac *LookupController) GetCupCategories(c *fiber.Ctx) error {
|
||||
result, err := ac.service.GetCupCategories(c)
|
||||
if err != nil {
|
||||
@@ -126,7 +126,7 @@ func (ac *LookupController) GetCupCategories(c *fiber.Ctx) error {
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/lookup/session-types [get]
|
||||
// @Router /lookup/session-types [get]
|
||||
func (ac *LookupController) GetSessionTypes(c *fiber.Ctx) error {
|
||||
result, err := ac.service.GetSessionTypes(c)
|
||||
if err != nil {
|
||||
|
||||
@@ -34,6 +34,7 @@ func NewMembershipController(service *service.MembershipService, auth *middlewar
|
||||
}
|
||||
|
||||
routeGroups.Auth.Post("/login", mc.Login)
|
||||
routeGroups.Auth.Post("/open-token", mc.auth.Authenticate, mc.GenerateOpenToken)
|
||||
|
||||
usersGroup := routeGroups.Membership
|
||||
usersGroup.Use(mc.auth.Authenticate)
|
||||
@@ -82,6 +83,26 @@ func (c *MembershipController) Login(ctx *fiber.Ctx) error {
|
||||
return ctx.JSON(fiber.Map{"token": token})
|
||||
}
|
||||
|
||||
// GenerateOpenToken generates an open token for a user.
|
||||
// @Summary Generate an open token
|
||||
// @Description Generate an open token for a user
|
||||
// @Tags Authentication
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} object{token=string} "JWT token"
|
||||
// @Failure 400 {object} error_handler.ErrorResponse "Invalid request body"
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Invalid credentials"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Router /auth/open-token [post]
|
||||
func (c *MembershipController) GenerateOpenToken(ctx *fiber.Ctx) error {
|
||||
token, err := c.service.GenerateOpenToken(ctx.UserContext(), ctx.Locals("userID").(string))
|
||||
if err != nil {
|
||||
return c.errorHandler.HandleAuthError(ctx, err)
|
||||
}
|
||||
|
||||
return ctx.JSON(fiber.Map{"token": token})
|
||||
}
|
||||
|
||||
// CreateUser creates a new user.
|
||||
// @Summary Create a new user
|
||||
// @Description Create a new user account with specified role
|
||||
@@ -139,6 +160,18 @@ func (mc *MembershipController) ListUsers(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// GetUser gets a single user by ID.
|
||||
// @Summary Get user by ID
|
||||
// @Description Get detailed information about a specific user
|
||||
// @Tags User Management
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "User ID (UUID format)"
|
||||
// @Success 200 {object} model.User "User details"
|
||||
// @Failure 400 {object} error_handler.ErrorResponse "Invalid user ID format"
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
|
||||
// @Failure 404 {object} error_handler.ErrorResponse "User not found"
|
||||
// @Security BearerAuth
|
||||
// @Router /membership/{id} [get]
|
||||
func (mc *MembershipController) GetUser(c *fiber.Ctx) error {
|
||||
id, err := uuid.Parse(c.Params("id"))
|
||||
if err != nil {
|
||||
@@ -154,6 +187,16 @@ func (mc *MembershipController) GetUser(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// GetMe returns the currently authenticated user's details.
|
||||
// @Summary Get current user details
|
||||
// @Description Get details of the currently authenticated user
|
||||
// @Tags Authentication
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.User "Current user details"
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
|
||||
// @Failure 404 {object} error_handler.ErrorResponse "User not found"
|
||||
// @Security BearerAuth
|
||||
// @Router /auth/me [get]
|
||||
func (mc *MembershipController) GetMe(c *fiber.Ctx) error {
|
||||
userID, ok := c.Locals("userID").(string)
|
||||
if !ok || userID == "" {
|
||||
@@ -172,6 +215,19 @@ func (mc *MembershipController) GetMe(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user.
|
||||
// @Summary Delete user
|
||||
// @Description Delete a specific user by ID
|
||||
// @Tags User Management
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "User ID (UUID format)"
|
||||
// @Success 204 "User successfully deleted"
|
||||
// @Failure 400 {object} error_handler.ErrorResponse "Invalid user ID format"
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
|
||||
// @Failure 403 {object} error_handler.ErrorResponse "Insufficient permissions"
|
||||
// @Failure 404 {object} error_handler.ErrorResponse "User not found"
|
||||
// @Security BearerAuth
|
||||
// @Router /membership/{id} [delete]
|
||||
func (mc *MembershipController) DeleteUser(c *fiber.Ctx) error {
|
||||
id, err := uuid.Parse(c.Params("id"))
|
||||
if err != nil {
|
||||
@@ -187,6 +243,20 @@ func (mc *MembershipController) DeleteUser(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// UpdateUser updates a user.
|
||||
// @Summary Update user
|
||||
// @Description Update user details by ID
|
||||
// @Tags User Management
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "User ID (UUID format)"
|
||||
// @Param user body service.UpdateUserRequest true "Updated user details"
|
||||
// @Success 200 {object} model.User "Updated user details"
|
||||
// @Failure 400 {object} error_handler.ErrorResponse "Invalid request body or ID format"
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
|
||||
// @Failure 403 {object} error_handler.ErrorResponse "Insufficient permissions"
|
||||
// @Failure 404 {object} error_handler.ErrorResponse "User not found"
|
||||
// @Security BearerAuth
|
||||
// @Router /membership/{id} [put]
|
||||
func (mc *MembershipController) UpdateUser(c *fiber.Ctx) error {
|
||||
id, err := uuid.Parse(c.Params("id"))
|
||||
if err != nil {
|
||||
@@ -207,6 +277,17 @@ func (mc *MembershipController) UpdateUser(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// GetRoles returns all available roles.
|
||||
// @Summary Get all roles
|
||||
// @Description Get a list of all available user roles
|
||||
// @Tags User Management
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {array} model.Role "List of roles"
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
|
||||
// @Failure 403 {object} error_handler.ErrorResponse "Insufficient permissions"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /membership/roles [get]
|
||||
func (mc *MembershipController) GetRoles(c *fiber.Ctx) error {
|
||||
roles, err := mc.service.GetAllRoles(c.UserContext())
|
||||
if err != nil {
|
||||
|
||||
@@ -29,7 +29,6 @@ func NewServerController(ss *service.ServerService, routeGroups *common.RouteGro
|
||||
serverRoutes.Get("/", auth.HasPermission(model.ServerView), ac.GetAll)
|
||||
serverRoutes.Get("/:id", auth.HasPermission(model.ServerView), ac.GetById)
|
||||
serverRoutes.Post("/", auth.HasPermission(model.ServerCreate), ac.CreateServer)
|
||||
serverRoutes.Put("/:id", auth.HasPermission(model.ServerUpdate), ac.UpdateServer)
|
||||
serverRoutes.Delete("/:id", auth.HasPermission(model.ServerDelete), ac.DeleteServer)
|
||||
|
||||
apiServerRoutes := routeGroups.Api.Group("/server")
|
||||
@@ -50,7 +49,7 @@ func NewServerController(ss *service.ServerService, routeGroups *common.RouteGro
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /server [get]
|
||||
// @Router /api/server [get]
|
||||
func (ac *ServerController) GetAllApi(c *fiber.Ctx) error {
|
||||
var filter model.ServerFilter
|
||||
if err := common.ParseQueryFilter(c, &filter); err != nil {
|
||||
@@ -79,7 +78,7 @@ func (ac *ServerController) GetAllApi(c *fiber.Ctx) error {
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/server [get]
|
||||
// @Router /server [get]
|
||||
func (ac *ServerController) GetAll(c *fiber.Ctx) error {
|
||||
var filter model.ServerFilter
|
||||
if err := common.ParseQueryFilter(c, &filter); err != nil {
|
||||
@@ -105,7 +104,7 @@ func (ac *ServerController) GetAll(c *fiber.Ctx) error {
|
||||
// @Failure 404 {object} error_handler.ErrorResponse "Server not found"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/server/{id} [get]
|
||||
// @Router /server/{id} [get]
|
||||
func (ac *ServerController) GetById(c *fiber.Ctx) error {
|
||||
serverIDStr := c.Params("id")
|
||||
serverID, err := uuid.Parse(serverIDStr)
|
||||
@@ -133,55 +132,40 @@ func (ac *ServerController) GetById(c *fiber.Ctx) error {
|
||||
// @Failure 403 {object} error_handler.ErrorResponse "Insufficient permissions"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/server [post]
|
||||
// @Router /server [post]
|
||||
func (ac *ServerController) CreateServer(c *fiber.Ctx) error {
|
||||
server := new(model.Server)
|
||||
if err := c.BodyParser(server); err != nil {
|
||||
return ac.errorHandler.HandleParsingError(c, err)
|
||||
}
|
||||
ac.service.GenerateServerPath(server)
|
||||
if err := ac.service.CreateServer(c, server); err != nil {
|
||||
|
||||
server.GenerateUUID()
|
||||
|
||||
// Use async server creation to avoid blocking other requests
|
||||
if err := ac.service.CreateServerAsync(c, server); err != nil {
|
||||
return ac.errorHandler.HandleServiceError(c, err)
|
||||
}
|
||||
|
||||
// Return immediately with server details
|
||||
// The actual creation will happen in the background with WebSocket updates
|
||||
return c.JSON(server)
|
||||
}
|
||||
|
||||
// UpdateServer updates an existing server
|
||||
// @Summary Update an ACC server
|
||||
// @Description Update configuration for an existing ACC server
|
||||
// DeleteServer deletes an existing server
|
||||
// @Summary Delete an ACC server
|
||||
// @Description Delete an existing ACC server
|
||||
// @Tags Server
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "Server ID (UUID format)"
|
||||
// @Param server body model.Server true "Updated server configuration"
|
||||
// @Success 200 {object} object "Updated server details"
|
||||
// @Success 200 {object} object "Deleted server details"
|
||||
// @Failure 400 {object} error_handler.ErrorResponse "Invalid server data or ID"
|
||||
// @Failure 401 {object} error_handler.ErrorResponse "Unauthorized"
|
||||
// @Failure 403 {object} error_handler.ErrorResponse "Insufficient permissions"
|
||||
// @Failure 404 {object} error_handler.ErrorResponse "Server not found"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/server/{id} [put]
|
||||
func (ac *ServerController) UpdateServer(c *fiber.Ctx) error {
|
||||
serverIDStr := c.Params("id")
|
||||
serverID, err := uuid.Parse(serverIDStr)
|
||||
if err != nil {
|
||||
return ac.errorHandler.HandleUUIDError(c, "server ID")
|
||||
}
|
||||
|
||||
server := new(model.Server)
|
||||
if err := c.BodyParser(server); err != nil {
|
||||
return ac.errorHandler.HandleParsingError(c, err)
|
||||
}
|
||||
server.ID = serverID
|
||||
|
||||
if err := ac.service.UpdateServer(c, server); err != nil {
|
||||
return ac.errorHandler.HandleServiceError(c, err)
|
||||
}
|
||||
return c.JSON(server)
|
||||
}
|
||||
|
||||
// DeleteServer deletes a server
|
||||
// @Router /server/{id} [delete]
|
||||
func (ac *ServerController) DeleteServer(c *fiber.Ctx) error {
|
||||
serverIDStr := c.Params("id")
|
||||
serverID, err := uuid.Parse(serverIDStr)
|
||||
|
||||
@@ -51,7 +51,7 @@ func NewServiceControlController(as *service.ServiceControlService, routeGroups
|
||||
// @Failure 404 {object} error_handler.ErrorResponse "Service not found"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/service-control/{service} [get]
|
||||
// @Router /server/{id}/service/{service} [get]
|
||||
func (ac *ServiceControlController) getStatus(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
c.Locals("serverId", id)
|
||||
@@ -78,7 +78,7 @@ func (ac *ServiceControlController) getStatus(c *fiber.Ctx) error {
|
||||
// @Failure 409 {object} error_handler.ErrorResponse "Service already running"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/service-control/start [post]
|
||||
// @Router /server/{id}/service/start [post]
|
||||
func (ac *ServiceControlController) startServer(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
c.Locals("serverId", id)
|
||||
@@ -105,7 +105,7 @@ func (ac *ServiceControlController) startServer(c *fiber.Ctx) error {
|
||||
// @Failure 409 {object} error_handler.ErrorResponse "Service already stopped"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/service-control/stop [post]
|
||||
// @Router /server/{id}/service/stop [post]
|
||||
func (ac *ServiceControlController) stopServer(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
c.Locals("serverId", id)
|
||||
@@ -131,7 +131,7 @@ func (ac *ServiceControlController) stopServer(c *fiber.Ctx) error {
|
||||
// @Failure 404 {object} error_handler.ErrorResponse "Service not found"
|
||||
// @Failure 500 {object} error_handler.ErrorResponse "Internal server error"
|
||||
// @Security BearerAuth
|
||||
// @Router /v1/service-control/restart [post]
|
||||
// @Router /server/{id}/service/restart [post]
|
||||
func (ac *ServiceControlController) restartServer(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
c.Locals("serverId", id)
|
||||
|
||||
@@ -42,7 +42,7 @@ func NewStateHistoryController(as *service.StateHistoryService, routeGroups *com
|
||||
// @Description Return StateHistorys
|
||||
// @Tags StateHistory
|
||||
// @Success 200 {array} string
|
||||
// @Router /v1/state-history [get]
|
||||
// @Router /state-history [get]
|
||||
func (ac *StateHistoryController) GetAll(c *fiber.Ctx) error {
|
||||
var filter model.StateHistoryFilter
|
||||
if err := common.ParseQueryFilter(c, &filter); err != nil {
|
||||
@@ -63,7 +63,7 @@ func (ac *StateHistoryController) GetAll(c *fiber.Ctx) error {
|
||||
// @Description Return StateHistorys
|
||||
// @Tags StateHistory
|
||||
// @Success 200 {array} string
|
||||
// @Router /v1/state-history/statistics [get]
|
||||
// @Router /state-history/statistics [get]
|
||||
func (ac *StateHistoryController) GetStatistics(c *fiber.Ctx) error {
|
||||
var filter model.StateHistoryFilter
|
||||
if err := common.ParseQueryFilter(c, &filter); err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/utl/common"
|
||||
"acc-server-manager/local/utl/configs"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
@@ -30,9 +31,9 @@ func NewSystemController(routeGroups *common.RouteGroups) *SystemController {
|
||||
//
|
||||
// @Summary Return service control status
|
||||
// @Description Return service control status
|
||||
// @Tags service-control
|
||||
// @Tags system
|
||||
// @Success 200 {array} string
|
||||
// @Router /v1/service-control [get]
|
||||
// @Router /system/health [get]
|
||||
func (ac *SystemController) getFirst(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
return c.SendString(configs.Version)
|
||||
}
|
||||
|
||||
138
local/controller/websocket.go
Normal file
138
local/controller/websocket.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/middleware"
|
||||
"acc-server-manager/local/service"
|
||||
"acc-server-manager/local/utl/common"
|
||||
"acc-server-manager/local/utl/jwt"
|
||||
"acc-server-manager/local/utl/logging"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type WebSocketController struct {
|
||||
webSocketService *service.WebSocketService
|
||||
jwtHandler *jwt.OpenJWTHandler
|
||||
}
|
||||
|
||||
func NewWebSocketController(
|
||||
wsService *service.WebSocketService,
|
||||
jwtHandler *jwt.OpenJWTHandler,
|
||||
routeGroups *common.RouteGroups,
|
||||
auth *middleware.AuthMiddleware,
|
||||
) *WebSocketController {
|
||||
wsc := &WebSocketController{
|
||||
webSocketService: wsService,
|
||||
jwtHandler: jwtHandler,
|
||||
}
|
||||
|
||||
wsRoutes := routeGroups.WebSocket
|
||||
wsRoutes.Use("/", wsc.upgradeWebSocket)
|
||||
wsRoutes.Get("/", websocket.New(wsc.handleWebSocket))
|
||||
|
||||
return wsc
|
||||
}
|
||||
|
||||
func (wsc *WebSocketController) upgradeWebSocket(c *fiber.Ctx) error {
|
||||
if websocket.IsWebSocketUpgrade(c) {
|
||||
token := c.Query("token")
|
||||
if token == "" {
|
||||
token = c.Get("Authorization")
|
||||
if token != "" && len(token) > 7 && token[:7] == "Bearer " {
|
||||
token = token[7:]
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "Missing authentication token")
|
||||
}
|
||||
|
||||
claims, err := wsc.jwtHandler.ValidateToken(token)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "Invalid authentication token")
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(claims.UserID)
|
||||
if err != nil {
|
||||
return fiber.NewError(fiber.StatusUnauthorized, "Invalid user ID in token")
|
||||
}
|
||||
|
||||
c.Locals("userID", userID)
|
||||
c.Locals("username", claims.UserID)
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
return fiber.NewError(fiber.StatusUpgradeRequired, "WebSocket upgrade required")
|
||||
}
|
||||
|
||||
func (wsc *WebSocketController) handleWebSocket(c *websocket.Conn) {
|
||||
connID := uuid.New().String()
|
||||
|
||||
userID, ok := c.Locals("userID").(uuid.UUID)
|
||||
if !ok {
|
||||
logging.Error("Failed to get user ID from WebSocket connection")
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
username, _ := c.Locals("username").(string)
|
||||
logging.Info("WebSocket connection established for user: %s (ID: %s)", username, userID.String())
|
||||
|
||||
wsc.webSocketService.AddConnection(connID, c, &userID)
|
||||
|
||||
defer func() {
|
||||
wsc.webSocketService.RemoveConnection(connID)
|
||||
logging.Info("WebSocket connection closed for user: %s", username)
|
||||
}()
|
||||
|
||||
for {
|
||||
messageType, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
logging.Error("WebSocket error for user %s: %v", username, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
switch messageType {
|
||||
case websocket.TextMessage:
|
||||
wsc.handleTextMessage(connID, userID, message)
|
||||
case websocket.BinaryMessage:
|
||||
logging.Debug("Received binary message from user %s (not supported)", username)
|
||||
case websocket.PingMessage:
|
||||
if err := c.WriteMessage(websocket.PongMessage, nil); err != nil {
|
||||
logging.Error("Failed to send pong to user %s: %v", username, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (wsc *WebSocketController) handleTextMessage(connID string, userID uuid.UUID, message []byte) {
|
||||
logging.Debug("Received WebSocket message from user %s: %s", userID.String(), string(message))
|
||||
|
||||
messageStr := string(message)
|
||||
if len(messageStr) > 10 && messageStr[:9] == "server_id" {
|
||||
if serverIDStr := messageStr[10:]; len(serverIDStr) > 0 {
|
||||
if serverID, err := uuid.Parse(serverIDStr); err == nil {
|
||||
wsc.webSocketService.SetServerID(connID, serverID)
|
||||
logging.Info("Associated WebSocket connection %s with server %s", connID, serverID.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (wsc *WebSocketController) GetWebSocketUpgrade() fiber.Handler {
|
||||
return wsc.upgradeWebSocket
|
||||
}
|
||||
|
||||
func (wsc *WebSocketController) GetWebSocketHandler() func(*websocket.Conn) {
|
||||
return wsc.handleWebSocket
|
||||
}
|
||||
|
||||
func (wsc *WebSocketController) BroadcastServerCreationProgress(serverID uuid.UUID, step string, status string, message string) {
|
||||
logging.Info("Broadcasting server creation progress: %s - %s: %s", serverID.String(), step, status)
|
||||
}
|
||||
@@ -10,12 +10,10 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// AccessKeyMiddleware provides authentication and permission middleware.
|
||||
type AccessKeyMiddleware struct {
|
||||
userInfo CachedUserInfo
|
||||
}
|
||||
|
||||
// NewAccessKeyMiddleware creates a new AccessKeyMiddleware.
|
||||
func NewAccessKeyMiddleware() *AccessKeyMiddleware {
|
||||
auth := &AccessKeyMiddleware{
|
||||
userInfo: CachedUserInfo{UserID: uuid.New().String(), Username: "access_key", RoleName: "Admin", Permissions: map[string]bool{
|
||||
@@ -25,9 +23,7 @@ func NewAccessKeyMiddleware() *AccessKeyMiddleware {
|
||||
return auth
|
||||
}
|
||||
|
||||
// Authenticate is a middleware for JWT authentication with enhanced security.
|
||||
func (m *AccessKeyMiddleware) Authenticate(ctx *fiber.Ctx) error {
|
||||
// Log authentication attempt
|
||||
ip := ctx.IP()
|
||||
userAgent := ctx.Get("User-Agent")
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// CachedUserInfo holds cached user authentication and permission data
|
||||
type CachedUserInfo struct {
|
||||
UserID string
|
||||
Username string
|
||||
@@ -25,34 +24,49 @@ type CachedUserInfo struct {
|
||||
CachedAt time.Time
|
||||
}
|
||||
|
||||
// AuthMiddleware provides authentication and permission middleware.
|
||||
type AuthMiddleware struct {
|
||||
membershipService *service.MembershipService
|
||||
cache *cache.InMemoryCache
|
||||
securityMW *security.SecurityMiddleware
|
||||
jwtHandler *jwt.JWTHandler
|
||||
openJWTHandler *jwt.OpenJWTHandler
|
||||
}
|
||||
|
||||
// NewAuthMiddleware creates a new AuthMiddleware.
|
||||
func NewAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache) *AuthMiddleware {
|
||||
func NewAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache, jwtHandler *jwt.JWTHandler, openJWTHandler *jwt.OpenJWTHandler) *AuthMiddleware {
|
||||
auth := &AuthMiddleware{
|
||||
membershipService: ms,
|
||||
cache: cache,
|
||||
securityMW: security.NewSecurityMiddleware(),
|
||||
jwtHandler: jwtHandler,
|
||||
openJWTHandler: openJWTHandler,
|
||||
}
|
||||
|
||||
// Set up bidirectional relationship for cache invalidation
|
||||
ms.SetCacheInvalidator(auth)
|
||||
|
||||
return auth
|
||||
}
|
||||
|
||||
// Authenticate is a middleware for JWT authentication with enhanced security.
|
||||
func (m *AuthMiddleware) AuthenticateOpen(ctx *fiber.Ctx) error {
|
||||
return m.AuthenticateWithHandler(m.openJWTHandler.JWTHandler, true, ctx)
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
|
||||
// Log authentication attempt
|
||||
return m.AuthenticateWithHandler(m.jwtHandler, false, ctx)
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) AuthenticateWithHandler(jwtHandler *jwt.JWTHandler, isOpenToken bool, ctx *fiber.Ctx) error {
|
||||
ip := ctx.IP()
|
||||
userAgent := ctx.Get("User-Agent")
|
||||
|
||||
authHeader := ctx.Get("Authorization")
|
||||
|
||||
if jwtHandler.IsOpenToken && !isOpenToken {
|
||||
logging.Error("Authentication failed: attempting to authenticate with open token")
|
||||
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
"error": "Wrong token type used",
|
||||
})
|
||||
}
|
||||
|
||||
if authHeader == "" {
|
||||
logging.Error("Authentication failed: missing Authorization header from IP %s", ip)
|
||||
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
@@ -68,7 +82,6 @@ func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
|
||||
})
|
||||
}
|
||||
|
||||
// Validate token length to prevent potential attacks
|
||||
token := parts[1]
|
||||
if len(token) < 10 || len(token) > 2048 {
|
||||
logging.Error("Authentication failed: invalid token length from IP %s", ip)
|
||||
@@ -77,7 +90,7 @@ func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
|
||||
})
|
||||
}
|
||||
|
||||
claims, err := jwt.ValidateToken(token)
|
||||
claims, err := jwtHandler.ValidateToken(token)
|
||||
if err != nil {
|
||||
logging.Error("Authentication failed: invalid token from IP %s, User-Agent: %s, Error: %v", ip, userAgent, err)
|
||||
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
@@ -85,7 +98,13 @@ func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
|
||||
})
|
||||
}
|
||||
|
||||
// Additional security: validate user ID format
|
||||
if !jwtHandler.IsOpenToken && claims.IsOpenToken {
|
||||
logging.Error("Authentication failed: attempting to authenticate with open token")
|
||||
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
"error": "Wrong token type used",
|
||||
})
|
||||
}
|
||||
|
||||
if claims.UserID == "" || len(claims.UserID) < 10 {
|
||||
logging.Error("Authentication failed: invalid user ID in token from IP %s", ip)
|
||||
return ctx.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
@@ -99,7 +118,6 @@ func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
|
||||
ctx.Locals("userInfo", userInfo)
|
||||
ctx.Locals("authTime", time.Now())
|
||||
} else {
|
||||
// Preload and cache user info to avoid database queries on permission checks
|
||||
userInfo, err := m.getCachedUserInfo(ctx.UserContext(), claims.UserID)
|
||||
if err != nil {
|
||||
logging.Error("Authentication failed: unable to load user info for %s from IP %s: %v", claims.UserID, ip, err)
|
||||
@@ -117,7 +135,6 @@ func (m *AuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
// HasPermission is a middleware for checking user permissions with enhanced logging.
|
||||
func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
userID, ok := ctx.Locals("userID").(string)
|
||||
@@ -132,7 +149,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
// Validate permission parameter
|
||||
if requiredPermission == "" {
|
||||
logging.Error("Permission check failed: empty permission requirement")
|
||||
return ctx.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
@@ -140,7 +156,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
|
||||
})
|
||||
}
|
||||
|
||||
// Use cached user info from authentication step - no database queries needed
|
||||
userInfo, ok := ctx.Locals("userInfo").(*CachedUserInfo)
|
||||
if !ok {
|
||||
logging.Error("Permission check failed: no cached user info in context from IP %s", ctx.IP())
|
||||
@@ -149,7 +164,6 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
|
||||
})
|
||||
}
|
||||
|
||||
// Check if user has permission using cached data
|
||||
has := m.hasPermissionFromCache(userInfo, requiredPermission)
|
||||
|
||||
if !has {
|
||||
@@ -164,16 +178,13 @@ func (m *AuthMiddleware) HasPermission(requiredPermission string) fiber.Handler
|
||||
}
|
||||
}
|
||||
|
||||
// AuthRateLimit applies rate limiting specifically for authentication endpoints
|
||||
func (m *AuthMiddleware) AuthRateLimit() fiber.Handler {
|
||||
return m.securityMW.AuthRateLimit()
|
||||
}
|
||||
|
||||
// RequireHTTPS redirects HTTP requests to HTTPS in production
|
||||
func (m *AuthMiddleware) RequireHTTPS() fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
if ctx.Protocol() != "https" && ctx.Get("X-Forwarded-Proto") != "https" {
|
||||
// Allow HTTP in development/testing
|
||||
if ctx.Hostname() != "localhost" && ctx.Hostname() != "127.0.0.1" {
|
||||
httpsURL := "https://" + ctx.Hostname() + ctx.OriginalURL()
|
||||
return ctx.Redirect(httpsURL, fiber.StatusMovedPermanently)
|
||||
@@ -183,11 +194,9 @@ func (m *AuthMiddleware) RequireHTTPS() fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// getCachedUserInfo retrieves and caches complete user information including permissions
|
||||
func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) (*CachedUserInfo, error) {
|
||||
cacheKey := fmt.Sprintf("userinfo:%s", userID)
|
||||
|
||||
// Try cache first
|
||||
if cached, found := m.cache.Get(cacheKey); found {
|
||||
if userInfo, ok := cached.(*CachedUserInfo); ok {
|
||||
logging.DebugWithContext("AUTH_CACHE", "User info for %s found in cache", userID)
|
||||
@@ -195,13 +204,11 @@ func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) (
|
||||
}
|
||||
}
|
||||
|
||||
// Cache miss - load from database
|
||||
user, err := m.membershipService.GetUserWithPermissions(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build permission map for fast lookups
|
||||
permissions := make(map[string]bool)
|
||||
for _, p := range user.Role.Permissions {
|
||||
permissions[p.Name] = true
|
||||
@@ -215,34 +222,26 @@ func (m *AuthMiddleware) getCachedUserInfo(ctx context.Context, userID string) (
|
||||
CachedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Cache for 15 minutes
|
||||
m.cache.Set(cacheKey, userInfo, 15*time.Minute)
|
||||
logging.DebugWithContext("AUTH_CACHE", "User info for %s cached with %d permissions", userID, len(permissions))
|
||||
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
// hasPermissionFromCache checks permissions using cached user info (no database queries)
|
||||
func (m *AuthMiddleware) hasPermissionFromCache(userInfo *CachedUserInfo, permission string) bool {
|
||||
// Super Admin and Admin have all permissions
|
||||
if userInfo.RoleName == "Super Admin" || userInfo.RoleName == "Admin" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check specific permission in cached map
|
||||
return userInfo.Permissions[permission]
|
||||
}
|
||||
|
||||
// InvalidateUserPermissions removes cached user info for a user
|
||||
func (m *AuthMiddleware) InvalidateUserPermissions(userID string) {
|
||||
cacheKey := fmt.Sprintf("userinfo:%s", userID)
|
||||
m.cache.Delete(cacheKey)
|
||||
logging.InfoWithContext("AUTH_CACHE", "User info cache invalidated for user %s", userID)
|
||||
}
|
||||
|
||||
// InvalidateAllUserPermissions clears all cached user info (useful for role/permission changes)
|
||||
func (m *AuthMiddleware) InvalidateAllUserPermissions() {
|
||||
// This would need to be implemented based on your cache interface
|
||||
// For now, just log that invalidation was requested
|
||||
logging.InfoWithContext("AUTH_CACHE", "All user info caches invalidation requested")
|
||||
}
|
||||
|
||||
@@ -7,25 +7,20 @@ import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// RequestLoggingMiddleware logs HTTP requests and responses
|
||||
type RequestLoggingMiddleware struct {
|
||||
infoLogger *logging.InfoLogger
|
||||
}
|
||||
|
||||
// NewRequestLoggingMiddleware creates a new request logging middleware
|
||||
func NewRequestLoggingMiddleware() *RequestLoggingMiddleware {
|
||||
return &RequestLoggingMiddleware{
|
||||
infoLogger: logging.GetInfoLogger(),
|
||||
}
|
||||
}
|
||||
|
||||
// Handler returns the middleware handler function
|
||||
func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Record start time
|
||||
start := time.Now()
|
||||
|
||||
// Log incoming request
|
||||
userAgent := c.Get("User-Agent")
|
||||
if userAgent == "" {
|
||||
userAgent = "Unknown"
|
||||
@@ -33,17 +28,13 @@ func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler {
|
||||
|
||||
rlm.infoLogger.LogRequest(c.Method(), c.OriginalURL(), userAgent)
|
||||
|
||||
// Continue to next handler
|
||||
err := c.Next()
|
||||
|
||||
// Calculate duration
|
||||
duration := time.Since(start)
|
||||
|
||||
// Log response
|
||||
statusCode := c.Response().StatusCode()
|
||||
rlm.infoLogger.LogResponse(c.Method(), c.OriginalURL(), statusCode, duration.String())
|
||||
|
||||
// Log error if present
|
||||
if err != nil {
|
||||
logging.ErrorWithContext("REQUEST_MIDDLEWARE", "Request failed: %v", err)
|
||||
}
|
||||
@@ -52,10 +43,8 @@ func (rlm *RequestLoggingMiddleware) Handler() fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// Global request logging middleware instance
|
||||
var globalRequestLoggingMiddleware *RequestLoggingMiddleware
|
||||
|
||||
// GetRequestLoggingMiddleware returns the global request logging middleware
|
||||
func GetRequestLoggingMiddleware() *RequestLoggingMiddleware {
|
||||
if globalRequestLoggingMiddleware == nil {
|
||||
globalRequestLoggingMiddleware = NewRequestLoggingMiddleware()
|
||||
@@ -63,7 +52,6 @@ func GetRequestLoggingMiddleware() *RequestLoggingMiddleware {
|
||||
return globalRequestLoggingMiddleware
|
||||
}
|
||||
|
||||
// Handler returns the global request logging middleware handler
|
||||
func Handler() fiber.Handler {
|
||||
return GetRequestLoggingMiddleware().Handler()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/utl/graceful"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
@@ -10,88 +11,76 @@ import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// RateLimiter stores rate limiting information
|
||||
type RateLimiter struct {
|
||||
requests map[string][]time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter
|
||||
func NewRateLimiter() *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
requests: make(map[string][]time.Time),
|
||||
}
|
||||
|
||||
// Clean up old entries every 5 minutes
|
||||
go rl.cleanup()
|
||||
shutdownManager := graceful.GetManager()
|
||||
shutdownManager.RunGoroutine(func(ctx context.Context) {
|
||||
rl.cleanupWithContext(ctx)
|
||||
})
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// cleanup removes old entries from the rate limiter
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
func (rl *RateLimiter) cleanupWithContext(ctx context.Context) {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rl.mutex.Lock()
|
||||
now := time.Now()
|
||||
for key, times := range rl.requests {
|
||||
// Remove entries older than 1 hour
|
||||
filtered := make([]time.Time, 0, len(times))
|
||||
for _, t := range times {
|
||||
if now.Sub(t) < time.Hour {
|
||||
filtered = append(filtered, t)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rl.mutex.Lock()
|
||||
now := time.Now()
|
||||
for key, times := range rl.requests {
|
||||
filtered := make([]time.Time, 0, len(times))
|
||||
for _, t := range times {
|
||||
if now.Sub(t) < time.Hour {
|
||||
filtered = append(filtered, t)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
delete(rl.requests, key)
|
||||
} else {
|
||||
rl.requests[key] = filtered
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
delete(rl.requests, key)
|
||||
} else {
|
||||
rl.requests[key] = filtered
|
||||
}
|
||||
rl.mutex.Unlock()
|
||||
}
|
||||
rl.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// SecurityMiddleware provides comprehensive security middleware
|
||||
type SecurityMiddleware struct {
|
||||
rateLimiter *RateLimiter
|
||||
}
|
||||
|
||||
// NewSecurityMiddleware creates a new security middleware
|
||||
func NewSecurityMiddleware() *SecurityMiddleware {
|
||||
return &SecurityMiddleware{
|
||||
rateLimiter: NewRateLimiter(),
|
||||
}
|
||||
}
|
||||
|
||||
// SecurityHeaders adds security headers to responses
|
||||
func (sm *SecurityMiddleware) SecurityHeaders() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Prevent MIME type sniffing
|
||||
c.Set("X-Content-Type-Options", "nosniff")
|
||||
|
||||
// Prevent clickjacking
|
||||
c.Set("X-Frame-Options", "DENY")
|
||||
|
||||
// Enable XSS protection
|
||||
c.Set("X-XSS-Protection", "1; mode=block")
|
||||
|
||||
// Prevent referrer leakage
|
||||
c.Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
|
||||
// Content Security Policy
|
||||
c.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none'")
|
||||
|
||||
// Permissions Policy
|
||||
c.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), interest-cohort=()")
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimit implements rate limiting for API endpoints
|
||||
func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
ip := c.IP()
|
||||
@@ -103,7 +92,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
|
||||
now := time.Now()
|
||||
requests := sm.rateLimiter.requests[key]
|
||||
|
||||
// Remove requests older than duration
|
||||
filtered := make([]time.Time, 0, len(requests))
|
||||
for _, t := range requests {
|
||||
if now.Sub(t) < duration {
|
||||
@@ -111,7 +99,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if limit is exceeded
|
||||
if len(filtered) >= maxRequests {
|
||||
return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
|
||||
"error": "Rate limit exceeded",
|
||||
@@ -119,7 +106,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
|
||||
})
|
||||
}
|
||||
|
||||
// Add current request
|
||||
filtered = append(filtered, now)
|
||||
sm.rateLimiter.requests[key] = filtered
|
||||
|
||||
@@ -127,7 +113,6 @@ func (sm *SecurityMiddleware) RateLimit(maxRequests int, duration time.Duration)
|
||||
}
|
||||
}
|
||||
|
||||
// AuthRateLimit implements stricter rate limiting for authentication endpoints
|
||||
func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
ip := c.IP()
|
||||
@@ -140,7 +125,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
|
||||
now := time.Now()
|
||||
requests := sm.rateLimiter.requests[key]
|
||||
|
||||
// Remove requests older than 15 minutes
|
||||
filtered := make([]time.Time, 0, len(requests))
|
||||
for _, t := range requests {
|
||||
if now.Sub(t) < 15*time.Minute {
|
||||
@@ -148,7 +132,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// Check if limit is exceeded (5 requests per 15 minutes for auth)
|
||||
if len(filtered) >= 5 {
|
||||
return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
|
||||
"error": "Too many authentication attempts",
|
||||
@@ -156,7 +139,6 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
// Add current request
|
||||
filtered = append(filtered, now)
|
||||
sm.rateLimiter.requests[key] = filtered
|
||||
|
||||
@@ -164,20 +146,16 @@ func (sm *SecurityMiddleware) AuthRateLimit() fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// InputSanitization sanitizes user input to prevent XSS and injection attacks
|
||||
func (sm *SecurityMiddleware) InputSanitization() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Sanitize query parameters
|
||||
c.Request().URI().QueryArgs().VisitAll(func(key, value []byte) {
|
||||
sanitized := sanitizeInput(string(value))
|
||||
c.Request().URI().QueryArgs().Set(string(key), sanitized)
|
||||
})
|
||||
|
||||
// Store original body for processing
|
||||
if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" {
|
||||
body := c.Body()
|
||||
if len(body) > 0 {
|
||||
// Basic sanitization - remove potentially dangerous patterns
|
||||
sanitized := sanitizeInput(string(body))
|
||||
c.Request().SetBodyString(sanitized)
|
||||
}
|
||||
@@ -187,15 +165,14 @@ func (sm *SecurityMiddleware) InputSanitization() fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// sanitizeInput removes potentially dangerous patterns from input
|
||||
func sanitizeInput(input string) string {
|
||||
// Remove common XSS patterns
|
||||
dangerous := []string{
|
||||
"<script",
|
||||
"</script>",
|
||||
"javascript:",
|
||||
"vbscript:",
|
||||
"data:text/html",
|
||||
"data:application",
|
||||
"onload=",
|
||||
"onerror=",
|
||||
"onclick=",
|
||||
@@ -204,28 +181,48 @@ func sanitizeInput(input string) string {
|
||||
"onblur=",
|
||||
"onchange=",
|
||||
"onsubmit=",
|
||||
"onkeydown=",
|
||||
"onkeyup=",
|
||||
"<iframe",
|
||||
"<object",
|
||||
"<embed",
|
||||
"<link",
|
||||
"<meta",
|
||||
"<style",
|
||||
"<form",
|
||||
"<input",
|
||||
"<button",
|
||||
"<svg",
|
||||
"<math",
|
||||
"expression(",
|
||||
"@import",
|
||||
"url(",
|
||||
"\\x",
|
||||
"\\u",
|
||||
"&#x",
|
||||
"&#",
|
||||
}
|
||||
|
||||
result := strings.ToLower(input)
|
||||
result := input
|
||||
lowerInput := strings.ToLower(input)
|
||||
|
||||
for _, pattern := range dangerous {
|
||||
result = strings.ReplaceAll(result, pattern, "")
|
||||
if strings.Contains(lowerInput, pattern) {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// If the sanitized version is very different, it might be malicious
|
||||
if len(result) < len(input)/2 {
|
||||
if strings.Contains(result, "\x00") {
|
||||
return ""
|
||||
}
|
||||
|
||||
return input
|
||||
if len(strings.TrimSpace(result)) == 0 && len(input) > 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidateContentType ensures only expected content types are accepted
|
||||
func (sm *SecurityMiddleware) ValidateContentType(allowedTypes ...string) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" {
|
||||
@@ -236,7 +233,6 @@ func (sm *SecurityMiddleware) ValidateContentType(allowedTypes ...string) fiber.
|
||||
})
|
||||
}
|
||||
|
||||
// Check if content type is allowed
|
||||
allowed := false
|
||||
for _, allowedType := range allowedTypes {
|
||||
if strings.Contains(contentType, allowedType) {
|
||||
@@ -256,7 +252,6 @@ func (sm *SecurityMiddleware) ValidateContentType(allowedTypes ...string) fiber.
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateUserAgent blocks requests with suspicious or missing user agents
|
||||
func (sm *SecurityMiddleware) ValidateUserAgent() fiber.Handler {
|
||||
suspiciousAgents := []string{
|
||||
"sqlmap",
|
||||
@@ -267,21 +262,19 @@ func (sm *SecurityMiddleware) ValidateUserAgent() fiber.Handler {
|
||||
"dirb",
|
||||
"dirbuster",
|
||||
"wpscan",
|
||||
"curl/7.0", // Very old curl versions
|
||||
"wget/1.0", // Very old wget versions
|
||||
"curl/7.0",
|
||||
"wget/1.0",
|
||||
}
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
userAgent := strings.ToLower(c.Get("User-Agent"))
|
||||
|
||||
// Block empty user agents
|
||||
if userAgent == "" {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": "User-Agent header is required",
|
||||
})
|
||||
}
|
||||
|
||||
// Block suspicious user agents
|
||||
for _, suspicious := range suspiciousAgents {
|
||||
if strings.Contains(userAgent, suspicious) {
|
||||
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
|
||||
@@ -294,7 +287,6 @@ func (sm *SecurityMiddleware) ValidateUserAgent() fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// RequestSizeLimit limits the size of incoming requests
|
||||
func (sm *SecurityMiddleware) RequestSizeLimit(maxSize int) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
if c.Method() == "POST" || c.Method() == "PUT" || c.Method() == "PATCH" {
|
||||
@@ -311,19 +303,15 @@ func (sm *SecurityMiddleware) RequestSizeLimit(maxSize int) fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// LogSecurityEvents logs security-related events
|
||||
func (sm *SecurityMiddleware) LogSecurityEvents() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
start := time.Now()
|
||||
|
||||
// Process request
|
||||
err := c.Next()
|
||||
|
||||
// Log suspicious activity
|
||||
status := c.Response().StatusCode()
|
||||
if status == 401 || status == 403 || status == 429 {
|
||||
duration := time.Since(start)
|
||||
// In a real implementation, you would send this to your logging system
|
||||
fmt.Printf("[SECURITY] %s %s %s %d %v %s\n",
|
||||
time.Now().Format(time.RFC3339),
|
||||
c.IP(),
|
||||
@@ -338,7 +326,6 @@ func (sm *SecurityMiddleware) LogSecurityEvents() fiber.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
// TimeoutMiddleware adds request timeout
|
||||
func (sm *SecurityMiddleware) TimeoutMiddleware(timeout time.Duration) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
ctx, cancel := context.WithTimeout(c.UserContext(), timeout)
|
||||
@@ -349,3 +336,24 @@ func (sm *SecurityMiddleware) TimeoutMiddleware(timeout time.Duration) fiber.Han
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SecurityMiddleware) RequestContextTimeout(timeout time.Duration) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
ctx, cancel := context.WithTimeout(c.UserContext(), timeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- c.Next()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return c.Status(fiber.StatusRequestTimeout).JSON(fiber.Map{
|
||||
"error": "Request timeout",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,21 +9,17 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Migration001UpgradePasswordSecurity migrates existing user passwords from encrypted to hashed format
|
||||
type Migration001UpgradePasswordSecurity struct {
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
// NewMigration001UpgradePasswordSecurity creates a new password security migration
|
||||
func NewMigration001UpgradePasswordSecurity(db *gorm.DB) *Migration001UpgradePasswordSecurity {
|
||||
return &Migration001UpgradePasswordSecurity{DB: db}
|
||||
}
|
||||
|
||||
// Up executes the migration
|
||||
func (m *Migration001UpgradePasswordSecurity) Up() error {
|
||||
logging.Info("Starting password security upgrade migration...")
|
||||
|
||||
// Check if migration has already been applied
|
||||
var migrationRecord MigrationRecord
|
||||
err := m.DB.Where("migration_name = ?", "001_upgrade_password_security").First(&migrationRecord).Error
|
||||
if err == nil {
|
||||
@@ -31,12 +27,10 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create migration tracking table if it doesn't exist
|
||||
if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil {
|
||||
return fmt.Errorf("failed to create migration tracking table: %v", err)
|
||||
}
|
||||
|
||||
// Start transaction
|
||||
tx := m.DB.Begin()
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("failed to start transaction: %v", tx.Error)
|
||||
@@ -47,16 +41,13 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
|
||||
}
|
||||
}()
|
||||
|
||||
// Add a backup column for old passwords (temporary)
|
||||
if err := tx.Exec("ALTER TABLE users ADD COLUMN password_backup TEXT").Error; err != nil {
|
||||
// Column might already exist, ignore if it's a duplicate column error
|
||||
if !isDuplicateColumnError(err) {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to add backup column: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get all users with encrypted passwords
|
||||
var users []UserForMigration
|
||||
if err := tx.Find(&users).Error; err != nil {
|
||||
tx.Rollback()
|
||||
@@ -72,19 +63,15 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
|
||||
if err := m.migrateUserPassword(tx, &user); err != nil {
|
||||
logging.Error("Failed to migrate user %s (ID: %s): %v", user.Username, user.ID, err)
|
||||
failedCount++
|
||||
// Continue with other users rather than failing completely
|
||||
continue
|
||||
}
|
||||
migratedCount++
|
||||
}
|
||||
|
||||
// Remove backup column after successful migration
|
||||
if err := tx.Exec("ALTER TABLE users DROP COLUMN password_backup").Error; err != nil {
|
||||
logging.Error("Failed to remove backup column (non-critical): %v", err)
|
||||
// Don't fail the migration for this
|
||||
}
|
||||
|
||||
// Record successful migration
|
||||
migrationRecord = MigrationRecord{
|
||||
MigrationName: "001_upgrade_password_security",
|
||||
AppliedAt: "datetime('now')",
|
||||
@@ -97,7 +84,6 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
|
||||
return fmt.Errorf("failed to record migration: %v", err)
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return fmt.Errorf("failed to commit migration: %v", err)
|
||||
}
|
||||
@@ -111,32 +97,24 @@ func (m *Migration001UpgradePasswordSecurity) Up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateUserPassword migrates a single user's password
|
||||
func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, user *UserForMigration) error {
|
||||
// Skip if password is already hashed (bcrypt hashes start with $2a$, $2b$, or $2y$)
|
||||
if isAlreadyHashed(user.Password) {
|
||||
logging.Debug("User %s already has hashed password, skipping", user.Username)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Backup original password
|
||||
if err := tx.Model(user).Update("password_backup", user.Password).Error; err != nil {
|
||||
return fmt.Errorf("failed to backup password: %v", err)
|
||||
}
|
||||
|
||||
// Try to decrypt the old password
|
||||
var plainPassword string
|
||||
|
||||
// First, try to decrypt using the old encryption method
|
||||
decrypted, err := decryptOldPassword(user.Password)
|
||||
if err != nil {
|
||||
// If decryption fails, the password might already be plain text or corrupted
|
||||
logging.Error("Failed to decrypt password for user %s, treating as plain text: %v", user.Username, err)
|
||||
|
||||
// Use original password as-is (might be plain text from development)
|
||||
plainPassword = user.Password
|
||||
|
||||
// Validate it's not obviously encrypted data
|
||||
if len(plainPassword) > 100 || containsBinaryData(plainPassword) {
|
||||
return fmt.Errorf("password appears to be corrupted encrypted data")
|
||||
}
|
||||
@@ -144,7 +122,6 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u
|
||||
plainPassword = decrypted
|
||||
}
|
||||
|
||||
// Validate plain password
|
||||
if plainPassword == "" {
|
||||
return errors.New("decrypted password is empty")
|
||||
}
|
||||
@@ -153,13 +130,11 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u
|
||||
return errors.New("password too short after decryption")
|
||||
}
|
||||
|
||||
// Hash the plain password using bcrypt
|
||||
hashedPassword, err := password.HashPassword(plainPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %v", err)
|
||||
}
|
||||
|
||||
// Update with hashed password
|
||||
if err := tx.Model(user).Update("password", hashedPassword).Error; err != nil {
|
||||
return fmt.Errorf("failed to update password: %v", err)
|
||||
}
|
||||
@@ -168,19 +143,16 @@ func (m *Migration001UpgradePasswordSecurity) migrateUserPassword(tx *gorm.DB, u
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserForMigration represents a user record for migration purposes
|
||||
type UserForMigration struct {
|
||||
ID string `gorm:"column:id"`
|
||||
Username string `gorm:"column:username"`
|
||||
Password string `gorm:"column:password"`
|
||||
}
|
||||
|
||||
// TableName specifies the table name for GORM
|
||||
func (UserForMigration) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
// MigrationRecord tracks applied migrations
|
||||
type MigrationRecord struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
MigrationName string `gorm:"unique;not null"`
|
||||
@@ -189,49 +161,38 @@ type MigrationRecord struct {
|
||||
Notes string
|
||||
}
|
||||
|
||||
// TableName specifies the table name for GORM
|
||||
func (MigrationRecord) TableName() string {
|
||||
return "migration_records"
|
||||
}
|
||||
|
||||
// isAlreadyHashed checks if a password is already bcrypt hashed
|
||||
func isAlreadyHashed(password string) bool {
|
||||
return len(password) >= 60 && (password[:4] == "$2a$" || password[:4] == "$2b$" || password[:4] == "$2y$")
|
||||
}
|
||||
|
||||
// containsBinaryData checks if a string contains binary data
|
||||
func containsBinaryData(s string) bool {
|
||||
for _, b := range []byte(s) {
|
||||
if b < 32 && b != 9 && b != 10 && b != 13 { // Allow tab, newline, carriage return
|
||||
if b < 32 && b != 9 && b != 10 && b != 13 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isDuplicateColumnError checks if an error is due to duplicate column
|
||||
func isDuplicateColumnError(err error) bool {
|
||||
errStr := err.Error()
|
||||
return fmt.Sprintf("%v", errStr) == "duplicate column name: password_backup" ||
|
||||
fmt.Sprintf("%v", errStr) == "SQLITE_ERROR: duplicate column name: password_backup"
|
||||
}
|
||||
|
||||
// decryptOldPassword attempts to decrypt using the old encryption method
|
||||
// This is a simplified version of the old DecryptPassword function
|
||||
func decryptOldPassword(encryptedPassword string) (string, error) {
|
||||
// This would use the old decryption logic
|
||||
// For now, we'll return an error to force treating as plain text
|
||||
// In a real scenario, you'd implement the old decryption here
|
||||
return "", errors.New("old decryption not implemented - treating as plain text")
|
||||
}
|
||||
|
||||
// Down reverses the migration (if needed)
|
||||
func (m *Migration001UpgradePasswordSecurity) Down() error {
|
||||
logging.Error("Password security migration rollback is not supported for security reasons")
|
||||
return errors.New("password security migration rollback is not supported")
|
||||
}
|
||||
|
||||
// RunMigration is a convenience function to run the migration
|
||||
func RunPasswordSecurityMigration(db *gorm.DB) error {
|
||||
migration := NewMigration001UpgradePasswordSecurity(db)
|
||||
return migration.Up()
|
||||
|
||||
@@ -9,21 +9,17 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Migration002MigrateToUUID migrates tables from integer IDs to UUIDs
|
||||
type Migration002MigrateToUUID struct {
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
// NewMigration002MigrateToUUID creates a new UUID migration
|
||||
func NewMigration002MigrateToUUID(db *gorm.DB) *Migration002MigrateToUUID {
|
||||
return &Migration002MigrateToUUID{DB: db}
|
||||
}
|
||||
|
||||
// Up executes the migration
|
||||
func (m *Migration002MigrateToUUID) Up() error {
|
||||
logging.Info("Checking UUID migration...")
|
||||
|
||||
// Check if migration is needed by looking at the servers table structure
|
||||
if !m.needsMigration() {
|
||||
logging.Info("UUID migration not needed - tables already use UUID primary keys")
|
||||
return nil
|
||||
@@ -31,7 +27,6 @@ func (m *Migration002MigrateToUUID) Up() error {
|
||||
|
||||
logging.Info("Starting UUID migration...")
|
||||
|
||||
// Check if migration has already been applied
|
||||
var migrationRecord MigrationRecord
|
||||
err := m.DB.Where("migration_name = ?", "002_migrate_to_uuid").First(&migrationRecord).Error
|
||||
if err == nil {
|
||||
@@ -39,12 +34,10 @@ func (m *Migration002MigrateToUUID) Up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create migration tracking table if it doesn't exist
|
||||
if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil {
|
||||
return fmt.Errorf("failed to create migration tracking table: %v", err)
|
||||
}
|
||||
|
||||
// Execute the UUID migration using the existing migration function
|
||||
logging.Info("Executing UUID migration...")
|
||||
if err := runUUIDMigrationSQL(m.DB); err != nil {
|
||||
return fmt.Errorf("failed to execute UUID migration: %v", err)
|
||||
@@ -54,9 +47,7 @@ func (m *Migration002MigrateToUUID) Up() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// needsMigration checks if the UUID migration is needed by examining table structure
|
||||
func (m *Migration002MigrateToUUID) needsMigration() bool {
|
||||
// Check if servers table exists and has integer primary key
|
||||
var result struct {
|
||||
Type string `gorm:"column:type"`
|
||||
}
|
||||
@@ -67,29 +58,22 @@ func (m *Migration002MigrateToUUID) needsMigration() bool {
|
||||
`).Scan(&result).Error
|
||||
|
||||
if err != nil || result.Type == "" {
|
||||
// Table doesn't exist or no primary key found - assume no migration needed
|
||||
return false
|
||||
}
|
||||
|
||||
// If the primary key is INTEGER, we need migration
|
||||
// If it's TEXT (UUID), migration already done
|
||||
return result.Type == "INTEGER" || result.Type == "integer"
|
||||
}
|
||||
|
||||
// Down reverses the migration (not implemented for safety)
|
||||
func (m *Migration002MigrateToUUID) Down() error {
|
||||
logging.Error("UUID migration rollback is not supported for data safety reasons")
|
||||
return fmt.Errorf("UUID migration rollback is not supported")
|
||||
}
|
||||
|
||||
// runUUIDMigrationSQL executes the UUID migration using the SQL file
|
||||
func runUUIDMigrationSQL(db *gorm.DB) error {
|
||||
// Disable foreign key constraints during migration
|
||||
if err := db.Exec("PRAGMA foreign_keys=OFF").Error; err != nil {
|
||||
return fmt.Errorf("failed to disable foreign keys: %v", err)
|
||||
}
|
||||
|
||||
// Start transaction
|
||||
tx := db.Begin()
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("failed to start transaction: %v", tx.Error)
|
||||
@@ -101,25 +85,21 @@ func runUUIDMigrationSQL(db *gorm.DB) error {
|
||||
}
|
||||
}()
|
||||
|
||||
// Read the migration SQL from file
|
||||
sqlPath := filepath.Join("scripts", "migrations", "002_migrate_servers_to_uuid.sql")
|
||||
migrationSQL, err := ioutil.ReadFile(sqlPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read migration SQL file: %v", err)
|
||||
}
|
||||
|
||||
// Execute the migration
|
||||
if err := tx.Exec(string(migrationSQL)).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to execute migration: %v", err)
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return fmt.Errorf("failed to commit migration: %v", err)
|
||||
}
|
||||
|
||||
// Re-enable foreign key constraints
|
||||
if err := db.Exec("PRAGMA foreign_keys=ON").Error; err != nil {
|
||||
return fmt.Errorf("failed to re-enable foreign keys: %v", err)
|
||||
}
|
||||
@@ -127,7 +107,6 @@ func runUUIDMigrationSQL(db *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RunUUIDMigration is a convenience function to run the migration
|
||||
func RunUUIDMigration(db *gorm.DB) error {
|
||||
migration := NewMigration002MigrateToUUID(db)
|
||||
return migration.Up()
|
||||
|
||||
106
local/migrations/003_update_state_history_sessions.go
Normal file
106
local/migrations/003_update_state_history_sessions.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UpdateStateHistorySessions struct {
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
func NewUpdateStateHistorySessions(db *gorm.DB) *UpdateStateHistorySessions {
|
||||
return &UpdateStateHistorySessions{DB: db}
|
||||
}
|
||||
|
||||
func (m *UpdateStateHistorySessions) Up() error {
|
||||
logging.Info("Checking UUID migration...")
|
||||
|
||||
if !m.needsMigration() {
|
||||
logging.Info("UUID migration not needed - tables already use UUID primary keys")
|
||||
return nil
|
||||
}
|
||||
|
||||
logging.Info("Starting UUID migration...")
|
||||
|
||||
var migrationRecord MigrationRecord
|
||||
err := m.DB.Where("migration_name = ?", "002_migrate_to_uuid").First(&migrationRecord).Error
|
||||
if err == nil {
|
||||
logging.Info("UUID migration already applied, skipping")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.DB.AutoMigrate(&MigrationRecord{}); err != nil {
|
||||
return fmt.Errorf("failed to create migration tracking table: %v", err)
|
||||
}
|
||||
|
||||
logging.Info("Executing UUID migration...")
|
||||
if err := runUUIDMigrationSQL(m.DB); err != nil {
|
||||
return fmt.Errorf("failed to execute UUID migration: %v", err)
|
||||
}
|
||||
|
||||
logging.Info("UUID migration completed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *UpdateStateHistorySessions) needsMigration() bool {
|
||||
var result struct {
|
||||
Exists bool `gorm:"column:exists"`
|
||||
}
|
||||
|
||||
err := m.DB.Raw(`
|
||||
SELECT count(*) > 0 as exists FROM state_history
|
||||
WHERE length(session) > 1 LIMIT 1;
|
||||
`).Scan(&result).Error
|
||||
|
||||
if err != nil || !result.Exists {
|
||||
return false
|
||||
}
|
||||
return result.Exists
|
||||
}
|
||||
|
||||
func (m *UpdateStateHistorySessions) Down() error {
|
||||
logging.Error("UUID migration rollback is not supported for data safety reasons")
|
||||
return fmt.Errorf("UUID migration rollback is not supported")
|
||||
}
|
||||
|
||||
func runUpdateStateHistorySessionsMigration(db *gorm.DB) error {
|
||||
if err := db.Exec("PRAGMA foreign_keys=OFF").Error; err != nil {
|
||||
return fmt.Errorf("failed to disable foreign keys: %v", err)
|
||||
}
|
||||
|
||||
tx := db.Begin()
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("failed to start transaction: %v", tx.Error)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
migrationSQL := "UPDATE state_history SET session = upper(substr(session, 1, 1));"
|
||||
|
||||
if err := tx.Exec(string(migrationSQL)).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to execute migration: %v", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return fmt.Errorf("failed to commit migration: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Exec("PRAGMA foreign_keys=ON").Error; err != nil {
|
||||
return fmt.Errorf("failed to re-enable foreign keys: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunUpdateStateHistorySessionsMigration(db *gorm.DB) error {
|
||||
migration := NewUpdateStateHistorySessions(db)
|
||||
return migration.Up()
|
||||
}
|
||||
@@ -6,20 +6,17 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// StatusCache represents a cached server status with expiration
|
||||
type StatusCache struct {
|
||||
Status ServiceStatus
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// CacheConfig holds configuration for cache behavior
|
||||
type CacheConfig struct {
|
||||
ExpirationTime time.Duration // How long before a cache entry expires
|
||||
ThrottleTime time.Duration // Minimum time between status checks
|
||||
DefaultStatus ServiceStatus // Default status to return when throttled
|
||||
ExpirationTime time.Duration
|
||||
ThrottleTime time.Duration
|
||||
DefaultStatus ServiceStatus
|
||||
}
|
||||
|
||||
// ServerStatusCache manages cached server statuses
|
||||
type ServerStatusCache struct {
|
||||
sync.RWMutex
|
||||
cache map[string]*StatusCache
|
||||
@@ -27,7 +24,6 @@ type ServerStatusCache struct {
|
||||
lastChecked map[string]time.Time
|
||||
}
|
||||
|
||||
// NewServerStatusCache creates a new server status cache
|
||||
func NewServerStatusCache(config CacheConfig) *ServerStatusCache {
|
||||
return &ServerStatusCache{
|
||||
cache: make(map[string]*StatusCache),
|
||||
@@ -36,12 +32,10 @@ func NewServerStatusCache(config CacheConfig) *ServerStatusCache {
|
||||
}
|
||||
}
|
||||
|
||||
// GetStatus retrieves the cached status or indicates if a fresh check is needed
|
||||
func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
// Check if we're being throttled
|
||||
if lastCheck, exists := c.lastChecked[serviceName]; exists {
|
||||
if time.Since(lastCheck) < c.config.ThrottleTime {
|
||||
if cached, ok := c.cache[serviceName]; ok {
|
||||
@@ -51,7 +45,6 @@ func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we have a valid cached entry
|
||||
if cached, ok := c.cache[serviceName]; ok {
|
||||
if time.Since(cached.UpdatedAt) < c.config.ExpirationTime {
|
||||
return cached.Status, false
|
||||
@@ -61,7 +54,6 @@ func (c *ServerStatusCache) GetStatus(serviceName string) (ServiceStatus, bool)
|
||||
return StatusUnknown, true
|
||||
}
|
||||
|
||||
// UpdateStatus updates the cache with a new status
|
||||
func (c *ServerStatusCache) UpdateStatus(serviceName string, status ServiceStatus) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -73,7 +65,6 @@ func (c *ServerStatusCache) UpdateStatus(serviceName string, status ServiceStatu
|
||||
c.lastChecked[serviceName] = time.Now()
|
||||
}
|
||||
|
||||
// InvalidateStatus removes a specific service from the cache
|
||||
func (c *ServerStatusCache) InvalidateStatus(serviceName string) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -82,7 +73,6 @@ func (c *ServerStatusCache) InvalidateStatus(serviceName string) {
|
||||
delete(c.lastChecked, serviceName)
|
||||
}
|
||||
|
||||
// Clear removes all entries from the cache
|
||||
func (c *ServerStatusCache) Clear() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -91,13 +81,11 @@ func (c *ServerStatusCache) Clear() {
|
||||
c.lastChecked = make(map[string]time.Time)
|
||||
}
|
||||
|
||||
// LookupCache provides a generic cache for lookup data
|
||||
type LookupCache struct {
|
||||
sync.RWMutex
|
||||
data map[string]interface{}
|
||||
}
|
||||
|
||||
// NewLookupCache creates a new lookup cache
|
||||
func NewLookupCache() *LookupCache {
|
||||
logging.Debug("Initializing new LookupCache")
|
||||
return &LookupCache{
|
||||
@@ -105,7 +93,6 @@ func NewLookupCache() *LookupCache {
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a cached value by key
|
||||
func (c *LookupCache) Get(key string) (interface{}, bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
@@ -119,7 +106,6 @@ func (c *LookupCache) Get(key string) (interface{}, bool) {
|
||||
return value, exists
|
||||
}
|
||||
|
||||
// Set stores a value in the cache
|
||||
func (c *LookupCache) Set(key string, value interface{}) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -128,7 +114,6 @@ func (c *LookupCache) Set(key string, value interface{}) {
|
||||
logging.Debug("Cache SET for key: %s", key)
|
||||
}
|
||||
|
||||
// Clear removes all entries from the cache
|
||||
func (c *LookupCache) Clear() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -137,13 +122,11 @@ func (c *LookupCache) Clear() {
|
||||
logging.Debug("Cache CLEARED")
|
||||
}
|
||||
|
||||
// ConfigEntry represents a cached configuration entry with its update time
|
||||
type ConfigEntry[T any] struct {
|
||||
Data T
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// getConfigFromCache is a generic helper function to retrieve cached configs
|
||||
func getConfigFromCache[T any](cache map[string]*ConfigEntry[T], serverID string, expirationTime time.Duration) (*T, bool) {
|
||||
if entry, ok := cache[serverID]; ok {
|
||||
if time.Since(entry.UpdatedAt) < expirationTime {
|
||||
@@ -157,7 +140,6 @@ func getConfigFromCache[T any](cache map[string]*ConfigEntry[T], serverID string
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// updateConfigInCache is a generic helper function to update cached configs
|
||||
func updateConfigInCache[T any](cache map[string]*ConfigEntry[T], serverID string, data T) {
|
||||
cache[serverID] = &ConfigEntry[T]{
|
||||
Data: data,
|
||||
@@ -166,7 +148,6 @@ func updateConfigInCache[T any](cache map[string]*ConfigEntry[T], serverID strin
|
||||
logging.Debug("Config cache SET for server ID: %s", serverID)
|
||||
}
|
||||
|
||||
// ServerConfigCache manages cached server configurations
|
||||
type ServerConfigCache struct {
|
||||
sync.RWMutex
|
||||
configuration map[string]*ConfigEntry[Configuration]
|
||||
@@ -177,7 +158,6 @@ type ServerConfigCache struct {
|
||||
config CacheConfig
|
||||
}
|
||||
|
||||
// NewServerConfigCache creates a new server configuration cache
|
||||
func NewServerConfigCache(config CacheConfig) *ServerConfigCache {
|
||||
logging.Debug("Initializing new ServerConfigCache with expiration time: %v, throttle time: %v", config.ExpirationTime, config.ThrottleTime)
|
||||
return &ServerConfigCache{
|
||||
@@ -190,7 +170,6 @@ func NewServerConfigCache(config CacheConfig) *ServerConfigCache {
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfiguration retrieves a cached configuration
|
||||
func (c *ServerConfigCache) GetConfiguration(serverID string) (*Configuration, bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
@@ -198,7 +177,6 @@ func (c *ServerConfigCache) GetConfiguration(serverID string) (*Configuration, b
|
||||
return getConfigFromCache(c.configuration, serverID, c.config.ExpirationTime)
|
||||
}
|
||||
|
||||
// GetAssistRules retrieves cached assist rules
|
||||
func (c *ServerConfigCache) GetAssistRules(serverID string) (*AssistRules, bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
@@ -206,7 +184,6 @@ func (c *ServerConfigCache) GetAssistRules(serverID string) (*AssistRules, bool)
|
||||
return getConfigFromCache(c.assistRules, serverID, c.config.ExpirationTime)
|
||||
}
|
||||
|
||||
// GetEvent retrieves cached event configuration
|
||||
func (c *ServerConfigCache) GetEvent(serverID string) (*EventConfig, bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
@@ -214,7 +191,6 @@ func (c *ServerConfigCache) GetEvent(serverID string) (*EventConfig, bool) {
|
||||
return getConfigFromCache(c.event, serverID, c.config.ExpirationTime)
|
||||
}
|
||||
|
||||
// GetEventRules retrieves cached event rules
|
||||
func (c *ServerConfigCache) GetEventRules(serverID string) (*EventRules, bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
@@ -222,7 +198,6 @@ func (c *ServerConfigCache) GetEventRules(serverID string) (*EventRules, bool) {
|
||||
return getConfigFromCache(c.eventRules, serverID, c.config.ExpirationTime)
|
||||
}
|
||||
|
||||
// GetSettings retrieves cached server settings
|
||||
func (c *ServerConfigCache) GetSettings(serverID string) (*ServerSettings, bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
@@ -230,7 +205,6 @@ func (c *ServerConfigCache) GetSettings(serverID string) (*ServerSettings, bool)
|
||||
return getConfigFromCache(c.settings, serverID, c.config.ExpirationTime)
|
||||
}
|
||||
|
||||
// UpdateConfiguration updates the configuration cache
|
||||
func (c *ServerConfigCache) UpdateConfiguration(serverID string, config Configuration) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -238,7 +212,6 @@ func (c *ServerConfigCache) UpdateConfiguration(serverID string, config Configur
|
||||
updateConfigInCache(c.configuration, serverID, config)
|
||||
}
|
||||
|
||||
// UpdateAssistRules updates the assist rules cache
|
||||
func (c *ServerConfigCache) UpdateAssistRules(serverID string, rules AssistRules) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -246,7 +219,6 @@ func (c *ServerConfigCache) UpdateAssistRules(serverID string, rules AssistRules
|
||||
updateConfigInCache(c.assistRules, serverID, rules)
|
||||
}
|
||||
|
||||
// UpdateEvent updates the event configuration cache
|
||||
func (c *ServerConfigCache) UpdateEvent(serverID string, event EventConfig) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -254,7 +226,6 @@ func (c *ServerConfigCache) UpdateEvent(serverID string, event EventConfig) {
|
||||
updateConfigInCache(c.event, serverID, event)
|
||||
}
|
||||
|
||||
// UpdateEventRules updates the event rules cache
|
||||
func (c *ServerConfigCache) UpdateEventRules(serverID string, rules EventRules) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -262,7 +233,6 @@ func (c *ServerConfigCache) UpdateEventRules(serverID string, rules EventRules)
|
||||
updateConfigInCache(c.eventRules, serverID, rules)
|
||||
}
|
||||
|
||||
// UpdateSettings updates the server settings cache
|
||||
func (c *ServerConfigCache) UpdateSettings(serverID string, settings ServerSettings) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -270,7 +240,6 @@ func (c *ServerConfigCache) UpdateSettings(serverID string, settings ServerSetti
|
||||
updateConfigInCache(c.settings, serverID, settings)
|
||||
}
|
||||
|
||||
// InvalidateServerCache removes all cached configurations for a specific server
|
||||
func (c *ServerConfigCache) InvalidateServerCache(serverID string) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
@@ -283,7 +252,6 @@ func (c *ServerConfigCache) InvalidateServerCache(serverID string) {
|
||||
delete(c.settings, serverID)
|
||||
}
|
||||
|
||||
// Clear removes all entries from the cache
|
||||
func (c *ServerConfigCache) Clear() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
@@ -13,17 +13,15 @@ import (
|
||||
type IntString int
|
||||
type IntBool int
|
||||
|
||||
// Config tracks configuration modifications
|
||||
type Config struct {
|
||||
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
|
||||
ServerID uuid.UUID `json:"serverId" gorm:"not null;type:uuid"`
|
||||
ConfigFile string `json:"configFile" gorm:"not null"` // e.g. "settings.json"
|
||||
ConfigFile string `json:"configFile" gorm:"not null"`
|
||||
OldConfig string `json:"oldConfig" gorm:"type:text"`
|
||||
NewConfig string `json:"newConfig" gorm:"type:text"`
|
||||
ChangedAt time.Time `json:"changedAt" gorm:"default:CURRENT_TIMESTAMP"`
|
||||
}
|
||||
|
||||
// BeforeCreate is a GORM hook that runs before creating new config entries
|
||||
func (c *Config) BeforeCreate(tx *gorm.DB) error {
|
||||
if c.ID == uuid.Nil {
|
||||
c.ID = uuid.New()
|
||||
@@ -79,11 +77,11 @@ type EventConfig struct {
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
HourOfDay IntString `json:"hourOfDay"`
|
||||
DayOfWeekend IntString `json:"dayOfWeekend"`
|
||||
TimeMultiplier IntString `json:"timeMultiplier"`
|
||||
SessionType string `json:"sessionType"`
|
||||
SessionDurationMinutes IntString `json:"sessionDurationMinutes"`
|
||||
HourOfDay IntString `json:"hourOfDay"`
|
||||
DayOfWeekend IntString `json:"dayOfWeekend"`
|
||||
TimeMultiplier IntString `json:"timeMultiplier"`
|
||||
SessionType TrackSession `json:"sessionType"`
|
||||
SessionDurationMinutes IntString `json:"sessionDurationMinutes"`
|
||||
}
|
||||
|
||||
type AssistRules struct {
|
||||
@@ -121,8 +119,6 @@ type Configuration struct {
|
||||
ConfigVersion IntString `json:"configVersion"`
|
||||
}
|
||||
|
||||
// Known configuration keys
|
||||
|
||||
func (i *IntBool) UnmarshalJSON(b []byte) error {
|
||||
var str int
|
||||
if err := json.Unmarshal(b, &str); err == nil && str <= 1 {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// BaseFilter contains common filter fields that can be embedded in other filters
|
||||
type BaseFilter struct {
|
||||
Page int `query:"page"`
|
||||
PageSize int `query:"page_size"`
|
||||
@@ -15,18 +14,15 @@ type BaseFilter struct {
|
||||
SortDesc bool `query:"sort_desc"`
|
||||
}
|
||||
|
||||
// DateRangeFilter adds date range filtering capabilities
|
||||
type DateRangeFilter struct {
|
||||
StartDate time.Time `query:"start_date" time_format:"2006-01-02T15:04:05Z07:00"`
|
||||
EndDate time.Time `query:"end_date" time_format:"2006-01-02T15:04:05Z07:00"`
|
||||
}
|
||||
|
||||
// ServerBasedFilter adds server ID filtering capability
|
||||
type ServerBasedFilter struct {
|
||||
ServerID string `param:"id"`
|
||||
}
|
||||
|
||||
// ConfigFilter defines filtering options for Config queries
|
||||
type ConfigFilter struct {
|
||||
BaseFilter
|
||||
ServerBasedFilter
|
||||
@@ -34,13 +30,11 @@ type ConfigFilter struct {
|
||||
ChangedAt time.Time `query:"changed_at" time_format:"2006-01-02T15:04:05Z07:00"`
|
||||
}
|
||||
|
||||
// ApiFilter defines filtering options for Api queries
|
||||
type ServiceControlFilter struct {
|
||||
BaseFilter
|
||||
ServiceControl string `query:"serviceControl"`
|
||||
}
|
||||
|
||||
// MembershipFilter defines filtering options for User queries
|
||||
type MembershipFilter struct {
|
||||
BaseFilter
|
||||
Username string `query:"username"`
|
||||
@@ -48,36 +42,32 @@ type MembershipFilter struct {
|
||||
RoleID string `query:"role_id"`
|
||||
}
|
||||
|
||||
// Pagination returns the offset and limit for database queries
|
||||
func (f *BaseFilter) Pagination() (offset, limit int) {
|
||||
if f.Page < 1 {
|
||||
f.Page = 1
|
||||
}
|
||||
if f.PageSize < 1 {
|
||||
f.PageSize = 10 // Default page size
|
||||
f.PageSize = 10
|
||||
}
|
||||
offset = (f.Page - 1) * f.PageSize
|
||||
limit = f.PageSize
|
||||
return
|
||||
}
|
||||
|
||||
// GetSorting returns the sort field and direction for database queries
|
||||
func (f *BaseFilter) GetSorting() (field string, desc bool) {
|
||||
if f.SortBy == "" {
|
||||
return "id", false // Default sorting
|
||||
return "id", false
|
||||
}
|
||||
return f.SortBy, f.SortDesc
|
||||
}
|
||||
|
||||
// IsDateRangeValid checks if both dates are set and start date is before end date
|
||||
func (f *DateRangeFilter) IsDateRangeValid() bool {
|
||||
if f.StartDate.IsZero() || f.EndDate.IsZero() {
|
||||
return true // If either date is not set, consider it valid
|
||||
return true
|
||||
}
|
||||
return f.StartDate.Before(f.EndDate)
|
||||
}
|
||||
|
||||
// ApplyFilter applies the membership filter to a GORM query
|
||||
func (f *MembershipFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
|
||||
if f.Username != "" {
|
||||
query = query.Where("username LIKE ?", "%"+f.Username+"%")
|
||||
@@ -93,12 +83,10 @@ func (f *MembershipFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
|
||||
return query
|
||||
}
|
||||
|
||||
// Pagination returns the offset and limit for database queries
|
||||
func (f *MembershipFilter) Pagination() (offset, limit int) {
|
||||
return f.BaseFilter.Pagination()
|
||||
}
|
||||
|
||||
// GetSorting returns the sort field and direction for database queries
|
||||
func (f *MembershipFilter) GetSorting() (field string, desc bool) {
|
||||
return f.BaseFilter.GetSorting()
|
||||
}
|
||||
|
||||
@@ -1,31 +1,26 @@
|
||||
package model
|
||||
|
||||
// Track represents a track and its capacity
|
||||
type Track struct {
|
||||
Name string `json:"track" gorm:"primaryKey;size:50"`
|
||||
UniquePitBoxes int `json:"unique_pit_boxes"`
|
||||
PrivateServerSlots int `json:"private_server_slots"`
|
||||
}
|
||||
|
||||
// CarModel represents a car model mapping
|
||||
type CarModel struct {
|
||||
Value int `json:"value" gorm:"primaryKey"`
|
||||
CarModel string `json:"car_model"`
|
||||
}
|
||||
|
||||
// DriverCategory represents driver skill categories
|
||||
type DriverCategory struct {
|
||||
Value int `json:"value" gorm:"primaryKey"`
|
||||
Category string `json:"category"`
|
||||
}
|
||||
|
||||
// CupCategory represents championship cup categories
|
||||
type CupCategory struct {
|
||||
Value int `json:"value" gorm:"primaryKey"`
|
||||
Category string `json:"category"`
|
||||
}
|
||||
|
||||
// SessionType represents session types
|
||||
type SessionType struct {
|
||||
Value int `json:"value" gorm:"primaryKey"`
|
||||
SessionType string `json:"session_type"`
|
||||
|
||||
@@ -32,8 +32,6 @@ type BaseModel struct {
|
||||
DateUpdated time.Time `json:"dateUpdated"`
|
||||
}
|
||||
|
||||
// Init
|
||||
// Initializes base model with DateCreated, DateUpdated, and Id values.
|
||||
func (cm *BaseModel) Init() {
|
||||
date := time.Now()
|
||||
cm.Id = uuid.NewString()
|
||||
|
||||
@@ -5,15 +5,13 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Permission represents an action that can be performed in the system.
|
||||
type Permission struct {
|
||||
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
|
||||
Name string `json:"name" gorm:"unique_index;not null"`
|
||||
}
|
||||
|
||||
// BeforeCreate is a GORM hook that runs before creating new credentials
|
||||
func (s *Permission) BeforeCreate(tx *gorm.DB) error {
|
||||
s.ID = uuid.New()
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package model
|
||||
|
||||
// Permission constants
|
||||
const (
|
||||
ServerView = "server.view"
|
||||
ServerCreate = "server.create"
|
||||
@@ -27,7 +26,6 @@ const (
|
||||
MembershipEdit = "membership.edit"
|
||||
)
|
||||
|
||||
// AllPermissions returns a slice of all permission strings.
|
||||
func AllPermissions() []string {
|
||||
return []string{
|
||||
ServerView,
|
||||
|
||||
@@ -5,16 +5,14 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Role represents a user role in the system.
|
||||
type Role struct {
|
||||
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
|
||||
Name string `json:"name" gorm:"unique_index;not null"`
|
||||
Permissions []Permission `json:"permissions" gorm:"many2many:role_permissions;"`
|
||||
}
|
||||
|
||||
// BeforeCreate is a GORM hook that runs before creating new credentials
|
||||
func (s *Role) BeforeCreate(tx *gorm.DB) error {
|
||||
s.ID = uuid.New()
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ const (
|
||||
ServiceNamePrefix = "ACC-Server"
|
||||
)
|
||||
|
||||
// Server represents an ACC server instance
|
||||
type ServerAPI struct {
|
||||
Name string `json:"name"`
|
||||
Status ServiceStatus `json:"status"`
|
||||
@@ -35,28 +34,27 @@ func (s *Server) ToServerAPI() *ServerAPI {
|
||||
}
|
||||
}
|
||||
|
||||
// Server represents an ACC server instance
|
||||
type Server struct {
|
||||
ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Status ServiceStatus `json:"status" gorm:"-"`
|
||||
IP string `gorm:"not null" json:"-"`
|
||||
Port int `gorm:"not null" json:"-"`
|
||||
Path string `gorm:"not null" json:"path"` // e.g. "/acc/servers/server1/"
|
||||
ServiceName string `gorm:"not null" json:"serviceName"` // Windows service name
|
||||
Path string `gorm:"not null" json:"path"`
|
||||
ServiceName string `gorm:"not null" json:"serviceName"`
|
||||
State *ServerState `gorm:"-" json:"state"`
|
||||
DateCreated time.Time `json:"dateCreated"`
|
||||
FromSteamCMD bool `gorm:"not null; default:true" json:"-"`
|
||||
}
|
||||
|
||||
type PlayerState struct {
|
||||
CarID int // Car ID in broadcast packets
|
||||
DriverName string // Optional: pulled from registration packet
|
||||
CarID int
|
||||
DriverName string
|
||||
TeamName string
|
||||
CarModel string
|
||||
CurrentLap int
|
||||
LastLapTime int // in milliseconds
|
||||
BestLapTime int // in milliseconds
|
||||
LastLapTime int
|
||||
BestLapTime int
|
||||
Position int
|
||||
ConnectedAt time.Time
|
||||
DisconnectedAt *time.Time
|
||||
@@ -67,23 +65,18 @@ type State struct {
|
||||
Session string `json:"session"`
|
||||
SessionStart time.Time `json:"sessionStart"`
|
||||
PlayerCount int `json:"playerCount"`
|
||||
// Players map[int]*PlayerState
|
||||
// etc.
|
||||
}
|
||||
|
||||
type ServerState struct {
|
||||
sync.RWMutex `swaggerignore:"-" json:"-"`
|
||||
Session string `json:"session"`
|
||||
SessionStart time.Time `json:"sessionStart"`
|
||||
PlayerCount int `json:"playerCount"`
|
||||
Track string `json:"track"`
|
||||
MaxConnections int `json:"maxConnections"`
|
||||
SessionDurationMinutes int `json:"sessionDurationMinutes"`
|
||||
// Players map[int]*PlayerState
|
||||
// etc.
|
||||
Session TrackSession `json:"session"`
|
||||
SessionStart time.Time `json:"sessionStart"`
|
||||
PlayerCount int `json:"playerCount"`
|
||||
Track string `json:"track"`
|
||||
MaxConnections int `json:"maxConnections"`
|
||||
SessionDurationMinutes int `json:"sessionDurationMinutes"`
|
||||
}
|
||||
|
||||
// ServerFilter defines filtering options for Server queries
|
||||
type ServerFilter struct {
|
||||
BaseFilter
|
||||
ServerBasedFilter
|
||||
@@ -92,9 +85,7 @@ type ServerFilter struct {
|
||||
Status string `query:"status"`
|
||||
}
|
||||
|
||||
// ApplyFilter implements the Filterable interface
|
||||
func (f *ServerFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
|
||||
// Apply server filter
|
||||
if f.ServerID != "" {
|
||||
if serverUUID, err := uuid.Parse(f.ServerID); err == nil {
|
||||
query = query.Where("id = ?", serverUUID)
|
||||
@@ -104,18 +95,19 @@ func (f *ServerFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
|
||||
return query
|
||||
}
|
||||
|
||||
// BeforeCreate is a GORM hook that runs before creating a new server
|
||||
func (s *Server) GenerateUUID() {
|
||||
if s.ID == uuid.Nil {
|
||||
s.ID = uuid.New()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) BeforeCreate(tx *gorm.DB) error {
|
||||
if s.Name == "" {
|
||||
return errors.New("server name is required")
|
||||
}
|
||||
|
||||
// Generate UUID if not set
|
||||
if s.ID == uuid.Nil {
|
||||
s.ID = uuid.New()
|
||||
}
|
||||
s.GenerateUUID()
|
||||
|
||||
// Generate service name and config path if not set
|
||||
if s.ServiceName == "" {
|
||||
s.ServiceName = s.GenerateServiceName()
|
||||
}
|
||||
@@ -123,7 +115,6 @@ func (s *Server) BeforeCreate(tx *gorm.DB) error {
|
||||
s.Path = s.GenerateServerPath(BaseServerPath)
|
||||
}
|
||||
|
||||
// Set creation date if not set
|
||||
if s.DateCreated.IsZero() {
|
||||
s.DateCreated = time.Now().UTC()
|
||||
}
|
||||
@@ -131,19 +122,14 @@ func (s *Server) BeforeCreate(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateServiceName creates a unique service name based on the server name
|
||||
func (s *Server) GenerateServiceName() string {
|
||||
// If ID is set, use it
|
||||
if s.ID != uuid.Nil {
|
||||
return fmt.Sprintf("%s-%s", ServiceNamePrefix, s.ID.String()[:8])
|
||||
}
|
||||
// Otherwise use a timestamp-based unique identifier
|
||||
return fmt.Sprintf("%s-%d", ServiceNamePrefix, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// GenerateServerPath creates the config path based on the service name
|
||||
func (s *Server) GenerateServerPath(steamCMDPath string) string {
|
||||
// Ensure service name is set
|
||||
if s.ServiceName == "" {
|
||||
s.ServiceName = s.GenerateServiceName()
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ const (
|
||||
StatusRunning
|
||||
)
|
||||
|
||||
// String converts the ServiceStatus to its string representation
|
||||
func (s ServiceStatus) String() string {
|
||||
switch s {
|
||||
case StatusRunning:
|
||||
@@ -35,7 +34,6 @@ func (s ServiceStatus) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
// ParseServiceStatus converts a string to ServiceStatus
|
||||
func ParseServiceStatus(s string) ServiceStatus {
|
||||
switch s {
|
||||
case "SERVICE_RUNNING":
|
||||
@@ -53,31 +51,24 @@ func ParseServiceStatus(s string) ServiceStatus {
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler interface
|
||||
func (s ServiceStatus) MarshalJSON() ([]byte, error) {
|
||||
// Return the numeric value instead of string
|
||||
return []byte(strconv.Itoa(int(s))), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler interface
|
||||
func (s *ServiceStatus) UnmarshalJSON(data []byte) error {
|
||||
// Try to parse as number first
|
||||
if i, err := strconv.Atoi(string(data)); err == nil {
|
||||
*s = ServiceStatus(i)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fallback to string parsing for backward compatibility
|
||||
str := string(data)
|
||||
if len(str) >= 2 {
|
||||
// Remove quotes if present
|
||||
str = str[1 : len(str)-1]
|
||||
}
|
||||
*s = ParseServiceStatus(str)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface
|
||||
func (s *ServiceStatus) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*s = StatusUnknown
|
||||
@@ -99,7 +90,6 @@ func (s *ServiceStatus) Scan(value interface{}) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface
|
||||
func (s ServiceStatus) Value() (driver.Value, error) {
|
||||
return s.String(), nil
|
||||
}
|
||||
|
||||
@@ -1,33 +1,31 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// StateHistoryFilter combines common filter capabilities
|
||||
type StateHistoryFilter struct {
|
||||
ServerBasedFilter // Adds server ID from path parameter
|
||||
DateRangeFilter // Adds date range filtering
|
||||
ServerBasedFilter
|
||||
DateRangeFilter
|
||||
|
||||
// Additional fields specific to state history
|
||||
Session string `query:"session"`
|
||||
MinPlayers *int `query:"min_players"`
|
||||
MaxPlayers *int `query:"max_players"`
|
||||
Session TrackSession `query:"session"`
|
||||
MinPlayers *int `query:"min_players"`
|
||||
MaxPlayers *int `query:"max_players"`
|
||||
}
|
||||
|
||||
// ApplyFilter implements the Filterable interface
|
||||
func (f *StateHistoryFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
|
||||
// Apply server filter
|
||||
if f.ServerID != "" {
|
||||
if serverUUID, err := uuid.Parse(f.ServerID); err == nil {
|
||||
query = query.Where("server_id = ?", serverUUID)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply date range filter if set
|
||||
timeZero := time.Time{}
|
||||
if f.StartDate != timeZero {
|
||||
query = query.Where("date_created >= ?", f.StartDate)
|
||||
@@ -36,12 +34,10 @@ func (f *StateHistoryFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
|
||||
query = query.Where("date_created <= ?", f.EndDate)
|
||||
}
|
||||
|
||||
// Apply session filter if set
|
||||
if f.Session != "" {
|
||||
query = query.Where("session = ?", f.Session)
|
||||
}
|
||||
|
||||
// Apply player count filters if set
|
||||
if f.MinPlayers != nil {
|
||||
query = query.Where("player_count >= ?", *f.MinPlayers)
|
||||
}
|
||||
@@ -52,19 +48,68 @@ func (f *StateHistoryFilter) ApplyFilter(query *gorm.DB) *gorm.DB {
|
||||
return query
|
||||
}
|
||||
|
||||
type StateHistory struct {
|
||||
ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"`
|
||||
ServerID uuid.UUID `json:"serverId" gorm:"not null;type:uuid"`
|
||||
Session string `json:"session"`
|
||||
Track string `json:"track"`
|
||||
PlayerCount int `json:"playerCount"`
|
||||
DateCreated time.Time `json:"dateCreated"`
|
||||
SessionStart time.Time `json:"sessionStart"`
|
||||
SessionDurationMinutes int `json:"sessionDurationMinutes"`
|
||||
SessionID uuid.UUID `json:"sessionId" gorm:"not null;type:uuid"` // Unique identifier for each session/event
|
||||
type TrackSession string
|
||||
|
||||
const (
|
||||
SessionPractice TrackSession = "P"
|
||||
SessionQualify TrackSession = "Q"
|
||||
SessionRace TrackSession = "R"
|
||||
SessionUnknown TrackSession = "U"
|
||||
)
|
||||
|
||||
func (i *TrackSession) UnmarshalJSON(b []byte) error {
|
||||
var str string
|
||||
if err := json.Unmarshal(b, &str); err == nil {
|
||||
*i = ToTrackSession(str)
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid TrackSession value")
|
||||
}
|
||||
|
||||
func (i TrackSession) Humanize() string {
|
||||
switch i {
|
||||
case SessionPractice:
|
||||
return "Practice"
|
||||
case SessionQualify:
|
||||
return "Qualifying"
|
||||
case SessionRace:
|
||||
return "Race"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func ToTrackSession(i string) TrackSession {
|
||||
sessionAbrv := strings.ToUpper(i[:1])
|
||||
switch sessionAbrv {
|
||||
case "P":
|
||||
return SessionPractice
|
||||
case "Q":
|
||||
return SessionQualify
|
||||
case "R":
|
||||
return SessionRace
|
||||
default:
|
||||
return SessionUnknown
|
||||
}
|
||||
}
|
||||
|
||||
func (i TrackSession) ToString() string {
|
||||
return string(i)
|
||||
}
|
||||
|
||||
type StateHistory struct {
|
||||
ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"`
|
||||
ServerID uuid.UUID `json:"serverId" gorm:"not null;type:uuid"`
|
||||
Session TrackSession `json:"session"`
|
||||
Track string `json:"track"`
|
||||
PlayerCount int `json:"playerCount"`
|
||||
DateCreated time.Time `json:"dateCreated"`
|
||||
SessionStart time.Time `json:"sessionStart"`
|
||||
SessionDurationMinutes int `json:"sessionDurationMinutes"`
|
||||
SessionID uuid.UUID `json:"sessionId" gorm:"not null;type:uuid"`
|
||||
}
|
||||
|
||||
// BeforeCreate is a GORM hook that runs before creating new state history entries
|
||||
func (sh *StateHistory) BeforeCreate(tx *gorm.DB) error {
|
||||
if sh.ID == uuid.Nil {
|
||||
sh.ID = uuid.New()
|
||||
|
||||
@@ -1,36 +1,38 @@
|
||||
package model
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type SessionCount struct {
|
||||
Name string `json:"name"`
|
||||
Count int `json:"count"`
|
||||
Name TrackSession `json:"name"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
type DailyActivity struct {
|
||||
Date string `json:"date"`
|
||||
SessionsCount int `json:"sessionsCount"`
|
||||
SessionsCount int `json:"sessionsCount"`
|
||||
}
|
||||
|
||||
type PlayerCountPoint struct {
|
||||
Timestamp string `json:"timestamp"`
|
||||
Count float64 `json:"count"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Count float64 `json:"count"`
|
||||
}
|
||||
|
||||
type StateHistoryStats struct {
|
||||
AveragePlayers float64 `json:"averagePlayers"`
|
||||
PeakPlayers int `json:"peakPlayers"`
|
||||
TotalSessions int `json:"totalSessions"`
|
||||
TotalPlaytime int `json:"totalPlaytime" gorm:"-"` // in minutes
|
||||
AveragePlayers float64 `json:"averagePlayers"`
|
||||
PeakPlayers int `json:"peakPlayers"`
|
||||
TotalSessions int `json:"totalSessions"`
|
||||
TotalPlaytime int `json:"totalPlaytime" gorm:"-"`
|
||||
PlayerCountOverTime []PlayerCountPoint `json:"playerCountOverTime" gorm:"-"`
|
||||
SessionTypes []SessionCount `json:"sessionTypes" gorm:"-"`
|
||||
DailyActivity []DailyActivity `json:"dailyActivity" gorm:"-"`
|
||||
RecentSessions []RecentSession `json:"recentSessions" gorm:"-"`
|
||||
}
|
||||
}
|
||||
|
||||
type RecentSession struct {
|
||||
ID uint `json:"id"`
|
||||
Date string `json:"date"`
|
||||
Type string `json:"type"`
|
||||
Track string `json:"track"`
|
||||
Duration int `json:"duration"`
|
||||
Players int `json:"players"`
|
||||
}
|
||||
ID uuid.UUID `json:"id"`
|
||||
Date string `json:"date"`
|
||||
Type TrackSession `json:"type"`
|
||||
Track string `json:"track"`
|
||||
Duration int `json:"duration"`
|
||||
Players int `json:"players"`
|
||||
}
|
||||
|
||||
168
local/model/steam_2fa.go
Normal file
168
local/model/steam_2fa.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Steam2FAStatus string
|
||||
|
||||
const (
|
||||
Steam2FAStatusIdle Steam2FAStatus = "idle"
|
||||
Steam2FAStatusPending Steam2FAStatus = "pending"
|
||||
Steam2FAStatusComplete Steam2FAStatus = "complete"
|
||||
Steam2FAStatusError Steam2FAStatus = "error"
|
||||
)
|
||||
|
||||
type Steam2FARequest struct {
|
||||
ID string `json:"id"`
|
||||
Status Steam2FAStatus `json:"status"`
|
||||
Message string `json:"message"`
|
||||
RequestTime time.Time `json:"requestTime"`
|
||||
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||
ErrorMsg string `json:"errorMsg,omitempty"`
|
||||
ServerID *uuid.UUID `json:"serverId,omitempty"`
|
||||
}
|
||||
|
||||
// Steam2FAManager manages 2FA requests and responses
|
||||
type Steam2FAManager struct {
|
||||
mu sync.RWMutex
|
||||
requests map[string]*Steam2FARequest
|
||||
channels map[string]chan bool
|
||||
}
|
||||
|
||||
func NewSteam2FAManager() *Steam2FAManager {
|
||||
return &Steam2FAManager{
|
||||
requests: make(map[string]*Steam2FARequest),
|
||||
channels: make(map[string]chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Steam2FAManager) CreateRequest(message string, serverID *uuid.UUID) *Steam2FARequest {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
id := uuid.New().String()
|
||||
request := &Steam2FARequest{
|
||||
ID: id,
|
||||
Status: Steam2FAStatusPending,
|
||||
Message: message,
|
||||
RequestTime: time.Now(),
|
||||
ServerID: serverID,
|
||||
}
|
||||
|
||||
m.requests[id] = request
|
||||
m.channels[id] = make(chan bool, 1)
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
func (m *Steam2FAManager) GetRequest(id string) (*Steam2FARequest, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
req, exists := m.requests[id]
|
||||
return req, exists
|
||||
}
|
||||
|
||||
func (m *Steam2FAManager) GetPendingRequests() []*Steam2FARequest {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var pending []*Steam2FARequest
|
||||
for _, req := range m.requests {
|
||||
if req.Status == Steam2FAStatusPending {
|
||||
pending = append(pending, req)
|
||||
}
|
||||
}
|
||||
return pending
|
||||
}
|
||||
|
||||
func (m *Steam2FAManager) CompleteRequest(id string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
req, exists := m.requests[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("request %s not found", id)
|
||||
}
|
||||
|
||||
if req.Status != Steam2FAStatusPending {
|
||||
return fmt.Errorf("request %s is not pending", id)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
req.Status = Steam2FAStatusComplete
|
||||
req.CompletedAt = &now
|
||||
|
||||
// Signal the waiting goroutine
|
||||
if ch, exists := m.channels[id]; exists {
|
||||
select {
|
||||
case ch <- true:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Steam2FAManager) ErrorRequest(id string, errorMsg string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
req, exists := m.requests[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("request %s not found", id)
|
||||
}
|
||||
|
||||
req.Status = Steam2FAStatusError
|
||||
req.ErrorMsg = errorMsg
|
||||
|
||||
// Signal the waiting goroutine with error
|
||||
if ch, exists := m.channels[id]; exists {
|
||||
select {
|
||||
case ch <- false:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Steam2FAManager) WaitForCompletion(id string, timeout time.Duration) (bool, error) {
|
||||
m.mu.RLock()
|
||||
ch, exists := m.channels[id]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false, fmt.Errorf("request %s not found", id)
|
||||
}
|
||||
|
||||
select {
|
||||
case success := <-ch:
|
||||
return success, nil
|
||||
case <-time.After(timeout):
|
||||
// Timeout - mark as error
|
||||
m.ErrorRequest(id, "timeout waiting for 2FA confirmation")
|
||||
return false, fmt.Errorf("timeout waiting for 2FA confirmation")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Steam2FAManager) CleanupOldRequests(maxAge time.Duration) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-maxAge)
|
||||
for id, req := range m.requests {
|
||||
if req.RequestTime.Before(cutoff) {
|
||||
delete(m.requests, id)
|
||||
if ch, exists := m.channels[id]; exists {
|
||||
close(ch)
|
||||
delete(m.channels, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -16,21 +16,18 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SteamCredentials represents stored Steam login credentials
|
||||
type SteamCredentials struct {
|
||||
ID uuid.UUID `gorm:"type:uuid;primary_key;" json:"id"`
|
||||
Username string `gorm:"not null" json:"username"`
|
||||
Password string `gorm:"not null" json:"-"` // Encrypted, not exposed in JSON
|
||||
Password string `gorm:"not null" json:"-"`
|
||||
DateCreated time.Time `json:"dateCreated"`
|
||||
LastUpdated time.Time `json:"lastUpdated"`
|
||||
}
|
||||
|
||||
// TableName specifies the table name for GORM
|
||||
func (SteamCredentials) TableName() string {
|
||||
return "steam_credentials"
|
||||
}
|
||||
|
||||
// BeforeCreate is a GORM hook that runs before creating new credentials
|
||||
func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error {
|
||||
if s.ID == uuid.Nil {
|
||||
s.ID = uuid.New()
|
||||
@@ -42,7 +39,6 @@ func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error {
|
||||
}
|
||||
s.LastUpdated = now
|
||||
|
||||
// Encrypt password before saving
|
||||
encrypted, err := EncryptPassword(s.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -52,11 +48,9 @@ func (s *SteamCredentials) BeforeCreate(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeforeUpdate is a GORM hook that runs before updating credentials
|
||||
func (s *SteamCredentials) BeforeUpdate(tx *gorm.DB) error {
|
||||
s.LastUpdated = time.Now().UTC()
|
||||
|
||||
// Only encrypt if password field is being updated
|
||||
if tx.Statement.Changed("Password") {
|
||||
encrypted, err := EncryptPassword(s.Password)
|
||||
if err != nil {
|
||||
@@ -68,9 +62,7 @@ func (s *SteamCredentials) BeforeUpdate(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind is a GORM hook that runs after fetching credentials
|
||||
func (s *SteamCredentials) AfterFind(tx *gorm.DB) error {
|
||||
// Decrypt password after fetching
|
||||
if s.Password != "" {
|
||||
decrypted, err := DecryptPassword(s.Password)
|
||||
if err != nil {
|
||||
@@ -81,18 +73,15 @@ func (s *SteamCredentials) AfterFind(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks if the credentials are valid with enhanced security checks
|
||||
func (s *SteamCredentials) Validate() error {
|
||||
if s.Username == "" {
|
||||
return errors.New("username is required")
|
||||
}
|
||||
|
||||
// Enhanced username validation
|
||||
if len(s.Username) < 3 || len(s.Username) > 64 {
|
||||
return errors.New("username must be between 3 and 64 characters")
|
||||
}
|
||||
|
||||
// Check for valid characters in username (alphanumeric, underscore, hyphen)
|
||||
if matched, _ := regexp.MatchString(`^[a-zA-Z0-9_-]+$`, s.Username); !matched {
|
||||
return errors.New("username contains invalid characters")
|
||||
}
|
||||
@@ -101,7 +90,6 @@ func (s *SteamCredentials) Validate() error {
|
||||
return errors.New("password is required")
|
||||
}
|
||||
|
||||
// Basic password validation
|
||||
if len(s.Password) < 6 {
|
||||
return errors.New("password must be at least 6 characters long")
|
||||
}
|
||||
@@ -110,7 +98,6 @@ func (s *SteamCredentials) Validate() error {
|
||||
return errors.New("password is too long")
|
||||
}
|
||||
|
||||
// Check for obvious weak passwords
|
||||
weakPasswords := []string{"password", "123456", "steam", "admin", "user"}
|
||||
lowerPass := strings.ToLower(s.Password)
|
||||
for _, weak := range weakPasswords {
|
||||
@@ -122,8 +109,6 @@ func (s *SteamCredentials) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEncryptionKey returns the encryption key from config.
|
||||
// The key is loaded from the ENCRYPTION_KEY environment variable.
|
||||
func GetEncryptionKey() []byte {
|
||||
key := []byte(configs.EncryptionKey)
|
||||
if len(key) != 32 {
|
||||
@@ -132,7 +117,6 @@ func GetEncryptionKey() []byte {
|
||||
return key
|
||||
}
|
||||
|
||||
// EncryptPassword encrypts a password using AES-256-GCM with enhanced security
|
||||
func EncryptPassword(password string) (string, error) {
|
||||
if password == "" {
|
||||
return "", errors.New("password cannot be empty")
|
||||
@@ -148,33 +132,27 @@ func EncryptPassword(password string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create a new GCM cipher
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create a cryptographically secure nonce
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Encrypt the password with authenticated encryption
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(password), nil)
|
||||
|
||||
// Return base64 encoded encrypted password
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// DecryptPassword decrypts an encrypted password with enhanced validation
|
||||
func DecryptPassword(encryptedPassword string) (string, error) {
|
||||
if encryptedPassword == "" {
|
||||
return "", errors.New("encrypted password cannot be empty")
|
||||
}
|
||||
|
||||
// Validate base64 format
|
||||
if len(encryptedPassword) < 24 { // Minimum reasonable length
|
||||
if len(encryptedPassword) < 24 {
|
||||
return "", errors.New("invalid encrypted password format")
|
||||
}
|
||||
|
||||
@@ -184,13 +162,11 @@ func DecryptPassword(encryptedPassword string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create a new GCM cipher
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Decode base64 encoded password
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encryptedPassword)
|
||||
if err != nil {
|
||||
return "", errors.New("invalid base64 encoding")
|
||||
@@ -207,7 +183,6 @@ func DecryptPassword(encryptedPassword string) (string, error) {
|
||||
return "", errors.New("decryption failed - invalid ciphertext or key")
|
||||
}
|
||||
|
||||
// Validate decrypted content
|
||||
decrypted := string(plaintext)
|
||||
if len(decrypted) == 0 || len(decrypted) > 1024 {
|
||||
return "", errors.New("invalid decrypted password")
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// User represents a user account in the system.
|
||||
type User struct {
|
||||
ID uuid.UUID `json:"id" gorm:"type:uuid;primary_key;"`
|
||||
Username string `json:"username" gorm:"unique_index;not null"`
|
||||
@@ -17,16 +16,13 @@ type User struct {
|
||||
Role Role `json:"role"`
|
||||
}
|
||||
|
||||
// BeforeCreate is a GORM hook that runs before creating new users
|
||||
func (s *User) BeforeCreate(tx *gorm.DB) error {
|
||||
s.ID = uuid.New()
|
||||
|
||||
// Validate password strength
|
||||
if err := password.ValidatePasswordStrength(s.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Hash password before saving
|
||||
hashed, err := password.HashPassword(s.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -36,11 +32,8 @@ func (s *User) BeforeCreate(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeforeUpdate is a GORM hook that runs before updating users
|
||||
func (s *User) BeforeUpdate(tx *gorm.DB) error {
|
||||
// Only hash if password field is being updated
|
||||
if tx.Statement.Changed("Password") {
|
||||
// Validate password strength
|
||||
if err := password.ValidatePasswordStrength(s.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -55,14 +48,10 @@ func (s *User) BeforeUpdate(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind is a GORM hook that runs after fetching users
|
||||
func (s *User) AfterFind(tx *gorm.DB) error {
|
||||
// Password remains hashed - never decrypt
|
||||
// This hook is kept for potential future use
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks if the user data is valid
|
||||
func (s *User) Validate() error {
|
||||
if s.Username == "" {
|
||||
return errors.New("username is required")
|
||||
@@ -73,7 +62,6 @@ func (s *User) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyPassword verifies a plain text password against the stored hash
|
||||
func (s *User) VerifyPassword(plainPassword string) error {
|
||||
return password.VerifyPassword(s.Password, plainPassword)
|
||||
}
|
||||
|
||||
80
local/model/websocket.go
Normal file
80
local/model/websocket.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type ServerCreationStep string
|
||||
|
||||
const (
|
||||
StepValidation ServerCreationStep = "validation"
|
||||
StepDirectoryCreation ServerCreationStep = "directory_creation"
|
||||
StepSteamDownload ServerCreationStep = "steam_download"
|
||||
StepConfigGeneration ServerCreationStep = "config_generation"
|
||||
StepServiceCreation ServerCreationStep = "service_creation"
|
||||
StepFirewallRules ServerCreationStep = "firewall_rules"
|
||||
StepDatabaseSave ServerCreationStep = "database_save"
|
||||
StepCompleted ServerCreationStep = "completed"
|
||||
)
|
||||
|
||||
type StepStatus string
|
||||
|
||||
const (
|
||||
StatusPending StepStatus = "pending"
|
||||
StatusInProgress StepStatus = "in_progress"
|
||||
StatusCompleted StepStatus = "completed"
|
||||
StatusFailed StepStatus = "failed"
|
||||
)
|
||||
|
||||
type WebSocketMessageType string
|
||||
|
||||
const (
|
||||
MessageTypeStep WebSocketMessageType = "step"
|
||||
MessageTypeSteamOutput WebSocketMessageType = "steam_output"
|
||||
MessageTypeError WebSocketMessageType = "error"
|
||||
MessageTypeComplete WebSocketMessageType = "complete"
|
||||
)
|
||||
|
||||
type WebSocketMessage struct {
|
||||
Type WebSocketMessageType `json:"type"`
|
||||
ServerID *uuid.UUID `json:"server_id,omitempty"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
type StepMessage struct {
|
||||
Step ServerCreationStep `json:"step"`
|
||||
Status StepStatus `json:"status"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type SteamOutputMessage struct {
|
||||
Output string `json:"output"`
|
||||
IsError bool `json:"is_error"`
|
||||
}
|
||||
|
||||
type ErrorMessage struct {
|
||||
Error string `json:"error"`
|
||||
Details string `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
type CompleteMessage struct {
|
||||
ServerID uuid.UUID `json:"server_id"`
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func GetStepDescription(step ServerCreationStep) string {
|
||||
descriptions := map[ServerCreationStep]string{
|
||||
StepValidation: "Validating server configuration",
|
||||
StepDirectoryCreation: "Creating server directories",
|
||||
StepSteamDownload: "Downloading server files via Steam",
|
||||
StepConfigGeneration: "Generating server configuration files",
|
||||
StepServiceCreation: "Creating Windows service",
|
||||
StepFirewallRules: "Configuring firewall rules",
|
||||
StepDatabaseSave: "Saving server to database",
|
||||
StepCompleted: "Server creation completed",
|
||||
}
|
||||
return descriptions[step]
|
||||
}
|
||||
@@ -7,13 +7,11 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// BaseRepository provides generic CRUD operations for any model
|
||||
type BaseRepository[T any, F any] struct {
|
||||
db *gorm.DB
|
||||
modelType T
|
||||
}
|
||||
|
||||
// NewBaseRepository creates a new base repository for the given model type
|
||||
func NewBaseRepository[T any, F any](db *gorm.DB, model T) *BaseRepository[T, F] {
|
||||
return &BaseRepository[T, F]{
|
||||
db: db,
|
||||
@@ -21,23 +19,19 @@ func NewBaseRepository[T any, F any](db *gorm.DB, model T) *BaseRepository[T, F]
|
||||
}
|
||||
}
|
||||
|
||||
// GetAll retrieves all records based on the filter
|
||||
func (r *BaseRepository[T, F]) GetAll(ctx context.Context, filter *F) (*[]T, error) {
|
||||
result := new([]T)
|
||||
query := r.db.WithContext(ctx).Model(&r.modelType)
|
||||
|
||||
// Apply filter conditions if filter implements Filterable
|
||||
if filterable, ok := any(filter).(Filterable); ok {
|
||||
query = filterable.ApplyFilter(query)
|
||||
}
|
||||
|
||||
// Apply pagination if filter implements Pageable
|
||||
if pageable, ok := any(filter).(Pageable); ok {
|
||||
offset, limit := pageable.Pagination()
|
||||
query = query.Offset(offset).Limit(limit)
|
||||
}
|
||||
|
||||
// Apply sorting if filter implements Sortable
|
||||
if sortable, ok := any(filter).(Sortable); ok {
|
||||
field, desc := sortable.GetSorting()
|
||||
if desc {
|
||||
@@ -54,7 +48,6 @@ func (r *BaseRepository[T, F]) GetAll(ctx context.Context, filter *F) (*[]T, err
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByID retrieves a single record by ID
|
||||
func (r *BaseRepository[T, F]) GetByID(ctx context.Context, id interface{}) (*T, error) {
|
||||
result := new(T)
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(result).Error; err != nil {
|
||||
@@ -66,7 +59,6 @@ func (r *BaseRepository[T, F]) GetByID(ctx context.Context, id interface{}) (*T,
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Insert creates a new record
|
||||
func (r *BaseRepository[T, F]) Insert(ctx context.Context, model *T) error {
|
||||
if err := r.db.WithContext(ctx).Create(model).Error; err != nil {
|
||||
return fmt.Errorf("error creating record: %w", err)
|
||||
@@ -74,7 +66,6 @@ func (r *BaseRepository[T, F]) Insert(ctx context.Context, model *T) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update modifies an existing record
|
||||
func (r *BaseRepository[T, F]) Update(ctx context.Context, model *T) error {
|
||||
if err := r.db.WithContext(ctx).Save(model).Error; err != nil {
|
||||
return fmt.Errorf("error updating record: %w", err)
|
||||
@@ -82,7 +73,6 @@ func (r *BaseRepository[T, F]) Update(ctx context.Context, model *T) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a record by ID
|
||||
func (r *BaseRepository[T, F]) Delete(ctx context.Context, id interface{}) error {
|
||||
if err := r.db.WithContext(ctx).Delete(new(T), id).Error; err != nil {
|
||||
return fmt.Errorf("error deleting record: %w", err)
|
||||
@@ -90,7 +80,6 @@ func (r *BaseRepository[T, F]) Delete(ctx context.Context, id interface{}) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count returns the total number of records matching the filter
|
||||
func (r *BaseRepository[T, F]) Count(ctx context.Context, filter *F) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&r.modelType)
|
||||
@@ -106,8 +95,6 @@ func (r *BaseRepository[T, F]) Count(ctx context.Context, filter *F) (int64, err
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// Interfaces for filter capabilities
|
||||
|
||||
type Filterable interface {
|
||||
ApplyFilter(*gorm.DB) *gorm.DB
|
||||
}
|
||||
|
||||
@@ -17,13 +17,11 @@ func NewConfigRepository(db *gorm.DB) *ConfigRepository {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig updates or creates a Config record
|
||||
func (r *ConfigRepository) UpdateConfig(ctx context.Context, config *model.Config) *model.Config {
|
||||
if err := r.Update(ctx, config); err != nil {
|
||||
// If update fails, try to insert
|
||||
if err := r.Insert(ctx, config); err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return config
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,20 +8,16 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// MembershipRepository handles database operations for users, roles, and permissions.
|
||||
type MembershipRepository struct {
|
||||
*BaseRepository[model.User, model.MembershipFilter]
|
||||
}
|
||||
|
||||
// NewMembershipRepository creates a new MembershipRepository.
|
||||
func NewMembershipRepository(db *gorm.DB) *MembershipRepository {
|
||||
return &MembershipRepository{
|
||||
BaseRepository: NewBaseRepository[model.User, model.MembershipFilter](db, model.User{}),
|
||||
}
|
||||
}
|
||||
|
||||
// FindUserByUsername finds a user by their username.
|
||||
// It preloads the user's role and the role's permissions.
|
||||
func (r *MembershipRepository) FindUserByUsername(ctx context.Context, username string) (*model.User, error) {
|
||||
var user model.User
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -32,7 +28,6 @@ func (r *MembershipRepository) FindUserByUsername(ctx context.Context, username
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// FindUserByIDWithPermissions finds a user by their ID and preloads Role and Permissions.
|
||||
func (r *MembershipRepository) FindUserByIDWithPermissions(ctx context.Context, userID string) (*model.User, error) {
|
||||
var user model.User
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -43,13 +38,11 @@ func (r *MembershipRepository) FindUserByIDWithPermissions(ctx context.Context,
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user.
|
||||
func (r *MembershipRepository) CreateUser(ctx context.Context, user *model.User) error {
|
||||
db := r.db.WithContext(ctx)
|
||||
return db.Create(user).Error
|
||||
}
|
||||
|
||||
// FindRoleByName finds a role by its name.
|
||||
func (r *MembershipRepository) FindRoleByName(ctx context.Context, name string) (*model.Role, error) {
|
||||
var role model.Role
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -60,13 +53,11 @@ func (r *MembershipRepository) FindRoleByName(ctx context.Context, name string)
|
||||
return &role, nil
|
||||
}
|
||||
|
||||
// CreateRole creates a new role.
|
||||
func (r *MembershipRepository) CreateRole(ctx context.Context, role *model.Role) error {
|
||||
db := r.db.WithContext(ctx)
|
||||
return db.Create(role).Error
|
||||
}
|
||||
|
||||
// FindPermissionByName finds a permission by its name.
|
||||
func (r *MembershipRepository) FindPermissionByName(ctx context.Context, name string) (*model.Permission, error) {
|
||||
var permission model.Permission
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -77,19 +68,16 @@ func (r *MembershipRepository) FindPermissionByName(ctx context.Context, name st
|
||||
return &permission, nil
|
||||
}
|
||||
|
||||
// CreatePermission creates a new permission.
|
||||
func (r *MembershipRepository) CreatePermission(ctx context.Context, permission *model.Permission) error {
|
||||
db := r.db.WithContext(ctx)
|
||||
return db.Create(permission).Error
|
||||
}
|
||||
|
||||
// AssignPermissionsToRole assigns a set of permissions to a role.
|
||||
func (r *MembershipRepository) AssignPermissionsToRole(ctx context.Context, role *model.Role, permissions []model.Permission) error {
|
||||
db := r.db.WithContext(ctx)
|
||||
return db.Model(role).Association("Permissions").Replace(permissions)
|
||||
}
|
||||
|
||||
// GetUserPermissions retrieves all permissions for a given user ID.
|
||||
func (r *MembershipRepository) GetUserPermissions(ctx context.Context, userID uuid.UUID) ([]string, error) {
|
||||
var user model.User
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -106,7 +94,6 @@ func (r *MembershipRepository) GetUserPermissions(ctx context.Context, userID uu
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
// ListUsers retrieves all users.
|
||||
func (r *MembershipRepository) ListUsers(ctx context.Context) ([]*model.User, error) {
|
||||
var users []*model.User
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -114,13 +101,11 @@ func (r *MembershipRepository) ListUsers(ctx context.Context) ([]*model.User, er
|
||||
return users, err
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user.
|
||||
func (r *MembershipRepository) DeleteUser(ctx context.Context, userID uuid.UUID) error {
|
||||
db := r.db.WithContext(ctx)
|
||||
return db.Delete(&model.User{}, "id = ?", userID).Error
|
||||
}
|
||||
|
||||
// FindUserByID finds a user by their ID.
|
||||
func (r *MembershipRepository) FindUserByID(ctx context.Context, userID uuid.UUID) (*model.User, error) {
|
||||
var user model.User
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -131,13 +116,11 @@ func (r *MembershipRepository) FindUserByID(ctx context.Context, userID uuid.UUI
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateUser updates a user's details in the database.
|
||||
func (r *MembershipRepository) UpdateUser(ctx context.Context, user *model.User) error {
|
||||
db := r.db.WithContext(ctx)
|
||||
return db.Save(user).Error
|
||||
}
|
||||
|
||||
// FindRoleByID finds a role by its ID.
|
||||
func (r *MembershipRepository) FindRoleByID(ctx context.Context, roleID uuid.UUID) (*model.Role, error) {
|
||||
var role model.Role
|
||||
db := r.db.WithContext(ctx)
|
||||
@@ -148,12 +131,10 @@ func (r *MembershipRepository) FindRoleByID(ctx context.Context, roleID uuid.UUI
|
||||
return &role, nil
|
||||
}
|
||||
|
||||
// ListUsersWithFilter retrieves users based on the membership filter.
|
||||
func (r *MembershipRepository) ListUsersWithFilter(ctx context.Context, filter *model.MembershipFilter) (*[]model.User, error) {
|
||||
return r.BaseRepository.GetAll(ctx, filter)
|
||||
}
|
||||
|
||||
// ListRoles retrieves all roles.
|
||||
func (r *MembershipRepository) ListRoles(ctx context.Context) ([]*model.Role, error) {
|
||||
var roles []*model.Role
|
||||
db := r.db.WithContext(ctx)
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/model"
|
||||
"acc-server-manager/local/utl/graceful"
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
// InitializeRepositories
|
||||
// Initializes Dependency Injection modules for repositories
|
||||
//
|
||||
// Args:
|
||||
// *dig.Container: Dig Container
|
||||
// *dig.Container: Dig Container
|
||||
func InitializeRepositories(c *dig.Container) {
|
||||
c.Provide(NewServiceControlRepository)
|
||||
c.Provide(NewStateHistoryRepository)
|
||||
@@ -17,4 +19,27 @@ func InitializeRepositories(c *dig.Container) {
|
||||
c.Provide(NewLookupRepository)
|
||||
c.Provide(NewSteamCredentialsRepository)
|
||||
c.Provide(NewMembershipRepository)
|
||||
|
||||
if err := c.Provide(func() *model.Steam2FAManager {
|
||||
manager := model.NewSteam2FAManager()
|
||||
|
||||
shutdownManager := graceful.GetManager()
|
||||
shutdownManager.RunGoroutine(func(ctx context.Context) {
|
||||
ticker := time.NewTicker(15 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
manager.CleanupOldRequests(30 * time.Minute)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return manager
|
||||
}); err != nil {
|
||||
logging.Panic("unable to initialize steam 2fa manager")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,10 +20,6 @@ func NewServerRepository(db *gorm.DB) *ServerRepository {
|
||||
return repo
|
||||
}
|
||||
|
||||
// GetFirstByServiceName
|
||||
// Gets first row from Server table.
|
||||
//
|
||||
// Args:
|
||||
// context.Context: Application context
|
||||
// Returns:
|
||||
// model.ServerModel: Server object from database.
|
||||
|
||||
@@ -18,17 +18,14 @@ func NewStateHistoryRepository(db *gorm.DB) *StateHistoryRepository {
|
||||
}
|
||||
}
|
||||
|
||||
// GetAll retrieves all state history records with the given filter
|
||||
func (r *StateHistoryRepository) GetAll(ctx context.Context, filter *model.StateHistoryFilter) (*[]model.StateHistory, error) {
|
||||
return r.BaseRepository.GetAll(ctx, filter)
|
||||
}
|
||||
|
||||
// Insert creates a new state history record
|
||||
func (r *StateHistoryRepository) Insert(ctx context.Context, model *model.StateHistory) error {
|
||||
return r.BaseRepository.Insert(ctx, model)
|
||||
}
|
||||
|
||||
// GetLastSessionID gets the last session ID for a server
|
||||
func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID uuid.UUID) (uuid.UUID, error) {
|
||||
var lastSession model.StateHistory
|
||||
result := r.BaseRepository.db.WithContext(ctx).
|
||||
@@ -38,7 +35,7 @@ func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID
|
||||
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return uuid.Nil, nil // Return nil UUID if no sessions found
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
return uuid.Nil, result.Error
|
||||
}
|
||||
@@ -46,10 +43,8 @@ func (r *StateHistoryRepository) GetLastSessionID(ctx context.Context, serverID
|
||||
return lastSession.SessionID, nil
|
||||
}
|
||||
|
||||
// GetSummaryStats calculates peak players, total sessions, and average players.
|
||||
func (r *StateHistoryRepository) GetSummaryStats(ctx context.Context, filter *model.StateHistoryFilter) (model.StateHistoryStats, error) {
|
||||
var stats model.StateHistoryStats
|
||||
// Parse ServerID to UUID for query
|
||||
serverUUID, err := uuid.Parse(filter.ServerID)
|
||||
if err != nil {
|
||||
return model.StateHistoryStats{}, err
|
||||
@@ -73,12 +68,10 @@ func (r *StateHistoryRepository) GetSummaryStats(ctx context.Context, filter *mo
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetTotalPlaytime calculates the total playtime in minutes.
|
||||
func (r *StateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *model.StateHistoryFilter) (int, error) {
|
||||
var totalPlaytime struct {
|
||||
TotalMinutes float64
|
||||
}
|
||||
// Parse ServerID to UUID for query
|
||||
serverUUID, err := uuid.Parse(filter.ServerID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -100,10 +93,8 @@ func (r *StateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *m
|
||||
return int(totalPlaytime.TotalMinutes), nil
|
||||
}
|
||||
|
||||
// GetPlayerCountOverTime gets downsampled player count data.
|
||||
func (r *StateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, filter *model.StateHistoryFilter) ([]model.PlayerCountPoint, error) {
|
||||
var points []model.PlayerCountPoint
|
||||
// Parse ServerID to UUID for query
|
||||
serverUUID, err := uuid.Parse(filter.ServerID)
|
||||
if err != nil {
|
||||
return points, err
|
||||
@@ -122,10 +113,8 @@ func (r *StateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, fil
|
||||
return points, err
|
||||
}
|
||||
|
||||
// GetSessionTypes counts sessions by type.
|
||||
func (r *StateHistoryRepository) GetSessionTypes(ctx context.Context, filter *model.StateHistoryFilter) ([]model.SessionCount, error) {
|
||||
var sessionTypes []model.SessionCount
|
||||
// Parse ServerID to UUID for query
|
||||
serverUUID, err := uuid.Parse(filter.ServerID)
|
||||
if err != nil {
|
||||
return sessionTypes, err
|
||||
@@ -145,10 +134,8 @@ func (r *StateHistoryRepository) GetSessionTypes(ctx context.Context, filter *mo
|
||||
return sessionTypes, err
|
||||
}
|
||||
|
||||
// GetDailyActivity counts sessions per day.
|
||||
func (r *StateHistoryRepository) GetDailyActivity(ctx context.Context, filter *model.StateHistoryFilter) ([]model.DailyActivity, error) {
|
||||
var dailyActivity []model.DailyActivity
|
||||
// Parse ServerID to UUID for query
|
||||
serverUUID, err := uuid.Parse(filter.ServerID)
|
||||
if err != nil {
|
||||
return dailyActivity, err
|
||||
@@ -167,10 +154,8 @@ func (r *StateHistoryRepository) GetDailyActivity(ctx context.Context, filter *m
|
||||
return dailyActivity, err
|
||||
}
|
||||
|
||||
// GetRecentSessions retrieves the 10 most recent sessions.
|
||||
func (r *StateHistoryRepository) GetRecentSessions(ctx context.Context, filter *model.StateHistoryFilter) ([]model.RecentSession, error) {
|
||||
var recentSessions []model.RecentSession
|
||||
// Parse ServerID to UUID for query
|
||||
serverUUID, err := uuid.Parse(filter.ServerID)
|
||||
if err != nil {
|
||||
return recentSessions, err
|
||||
@@ -187,7 +172,7 @@ func (r *StateHistoryRepository) GetRecentSessions(ctx context.Context, filter *
|
||||
FROM state_histories
|
||||
WHERE server_id = ? AND date_created BETWEEN ? AND ?
|
||||
GROUP BY session_id
|
||||
HAVING COUNT(*) > 1 AND MAX(player_count) > 0
|
||||
HAVING MAX(player_count) > 0
|
||||
ORDER BY date DESC
|
||||
LIMIT 10
|
||||
`
|
||||
|
||||
@@ -77,8 +77,8 @@ func NewConfigService(repository *repository.ConfigRepository, serverRepository
|
||||
repository: repository,
|
||||
serverRepository: serverRepository,
|
||||
configCache: model.NewServerConfigCache(model.CacheConfig{
|
||||
ExpirationTime: 5 * time.Minute, // Cache configs for 5 minutes
|
||||
ThrottleTime: 1 * time.Second, // Prevent rapid re-reads
|
||||
ExpirationTime: 5 * time.Minute,
|
||||
ThrottleTime: 1 * time.Second,
|
||||
DefaultStatus: model.StatusUnknown,
|
||||
}),
|
||||
}
|
||||
@@ -88,10 +88,6 @@ func (as *ConfigService) SetServerService(serverService *ServerService) {
|
||||
as.serverService = serverService
|
||||
}
|
||||
|
||||
// UpdateConfig
|
||||
// Updates physical config file and caches it in database.
|
||||
//
|
||||
// Args:
|
||||
// context.Context: Application context
|
||||
// Returns:
|
||||
// string: Application version
|
||||
@@ -103,7 +99,62 @@ func (as *ConfigService) UpdateConfig(ctx *fiber.Ctx, body *map[string]interface
|
||||
return as.updateConfigInternal(ctx.UserContext(), serverID, configFile, body, override)
|
||||
}
|
||||
|
||||
// updateConfigInternal handles the actual config update logic without Fiber dependencies
|
||||
func (as *ConfigService) updateConfigFiles(ctx context.Context, server *model.Server, configFile string, body *map[string]interface{}, override bool) ([]byte, []byte, error) {
|
||||
if server == nil {
|
||||
logging.Error("Server not found")
|
||||
return nil, nil, fmt.Errorf("server not found")
|
||||
}
|
||||
|
||||
configPath := filepath.Join(server.GetConfigPath(), configFile)
|
||||
oldData, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
dir := filepath.Dir(configPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := os.WriteFile(configPath, []byte("{}"), 0644); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
oldData = []byte("{}")
|
||||
} else {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
oldDataUTF8, err := DecodeUTF16LEBOM(oldData)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
newData, err := json.Marshal(&body)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if !override {
|
||||
newData, err = jsons.Merge(oldDataUTF8, newData)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
newData, err = common.IndentJson(newData)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
newDataUTF16, err := EncodeUTF16LEBOM(newData)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := os.WriteFile(configPath, newDataUTF16, 0644); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return oldDataUTF8, newData, nil
|
||||
}
|
||||
|
||||
func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID string, configFile string, body *map[string]interface{}, override bool) (*model.Config, error) {
|
||||
serverUUID, err := uuid.Parse(serverID)
|
||||
if err != nil {
|
||||
@@ -117,63 +168,14 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri
|
||||
return nil, fmt.Errorf("server not found")
|
||||
}
|
||||
|
||||
// Read existing config
|
||||
configPath := filepath.Join(server.GetConfigPath(), configFile)
|
||||
oldData, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// Create directory if it doesn't exist
|
||||
dir := filepath.Dir(configPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Create empty JSON file
|
||||
if err := os.WriteFile(configPath, []byte("{}"), 0644); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
oldData = []byte("{}")
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
oldDataUTF8, err := DecodeUTF16LEBOM(oldData)
|
||||
oldDataUTF8, newData, err := as.updateConfigFiles(ctx, server, configFile, body, override)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write new config
|
||||
newData, err := json.Marshal(&body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !override {
|
||||
newData, err = jsons.Merge(oldDataUTF8, newData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
newData, err = common.IndentJson(newData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newDataUTF16, err := EncodeUTF16LEBOM(newData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := os.WriteFile(configPath, newDataUTF16, 0644); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Invalidate all configs for this server since configs can be interdependent
|
||||
as.configCache.InvalidateServerCache(serverID)
|
||||
|
||||
as.serverService.StartAccServerRuntime(server)
|
||||
|
||||
// Log change
|
||||
return as.repository.UpdateConfig(ctx, &model.Config{
|
||||
ServerID: serverUUID,
|
||||
ConfigFile: configFile,
|
||||
@@ -183,10 +185,6 @@ func (as *ConfigService) updateConfigInternal(ctx context.Context, serverID stri
|
||||
}), nil
|
||||
}
|
||||
|
||||
// GetConfig
|
||||
// Gets physical config file and caches it in database.
|
||||
//
|
||||
// Args:
|
||||
// context.Context: Application context
|
||||
// Returns:
|
||||
// string: Application version
|
||||
@@ -197,44 +195,47 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) {
|
||||
logging.Debug("Getting config for server ID: %s, file: %s", serverIDStr, configFile)
|
||||
|
||||
server, err := as.serverRepository.GetByID(ctx.UserContext(), serverIDStr)
|
||||
|
||||
if err != nil {
|
||||
logging.Error("Server not found")
|
||||
return nil, fiber.NewError(404, "Server not found")
|
||||
}
|
||||
return as.getConfigFile(server, configFile)
|
||||
}
|
||||
|
||||
// Try to get from cache based on config file type
|
||||
func (as *ConfigService) getConfigFile(server *model.Server, configFile string) (interface{}, error) {
|
||||
serverIDStr := server.ID.String()
|
||||
switch configFile {
|
||||
case ConfigurationJson:
|
||||
if cached, ok := as.configCache.GetConfiguration(serverIDStr); ok {
|
||||
logging.Debug("Returning cached configuration for server ID: %s", serverIDStr)
|
||||
return cached, nil
|
||||
return *cached, nil
|
||||
}
|
||||
case AssistRulesJson:
|
||||
if cached, ok := as.configCache.GetAssistRules(serverIDStr); ok {
|
||||
logging.Debug("Returning cached assist rules for server ID: %s", serverIDStr)
|
||||
return cached, nil
|
||||
return *cached, nil
|
||||
}
|
||||
case EventJson:
|
||||
if cached, ok := as.configCache.GetEvent(serverIDStr); ok {
|
||||
logging.Debug("Returning cached event config for server ID: %s", serverIDStr)
|
||||
return cached, nil
|
||||
return *cached, nil
|
||||
}
|
||||
case EventRulesJson:
|
||||
if cached, ok := as.configCache.GetEventRules(serverIDStr); ok {
|
||||
logging.Debug("Returning cached event rules for server ID: %s", serverIDStr)
|
||||
return cached, nil
|
||||
return *cached, nil
|
||||
}
|
||||
case SettingsJson:
|
||||
if cached, ok := as.configCache.GetSettings(serverIDStr); ok {
|
||||
logging.Debug("Returning cached settings for server ID: %s", serverIDStr)
|
||||
return cached, nil
|
||||
return *cached, nil
|
||||
}
|
||||
}
|
||||
|
||||
logging.Debug("Cache miss for server ID: %s, file: %s - loading from disk", serverIDStr, configFile)
|
||||
|
||||
// Not in cache, load from disk
|
||||
configPath := filepath.Join(server.GetConfigPath(), configFile)
|
||||
configPath := server.GetConfigPath()
|
||||
decoder := DecodeFileName(configFile)
|
||||
if decoder == nil {
|
||||
return nil, errors.New("invalid config file")
|
||||
@@ -244,43 +245,39 @@ func (as *ConfigService) GetConfig(ctx *fiber.Ctx) (interface{}, error) {
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
logging.Debug("Config file not found, creating default for server ID: %s, file: %s", serverIDStr, configFile)
|
||||
// Return empty config if file doesn't exist
|
||||
switch configFile {
|
||||
case ConfigurationJson:
|
||||
return &model.Configuration{}, nil
|
||||
return model.Configuration{}, nil
|
||||
case AssistRulesJson:
|
||||
return &model.AssistRules{}, nil
|
||||
return model.AssistRules{}, nil
|
||||
case EventJson:
|
||||
return &model.EventConfig{}, nil
|
||||
return model.EventConfig{}, nil
|
||||
case EventRulesJson:
|
||||
return &model.EventRules{}, nil
|
||||
return model.EventRules{}, nil
|
||||
case SettingsJson:
|
||||
return &model.ServerSettings{}, nil
|
||||
return model.ServerSettings{}, nil
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cache the loaded config
|
||||
switch configFile {
|
||||
case ConfigurationJson:
|
||||
as.configCache.UpdateConfiguration(serverIDStr, *config.(*model.Configuration))
|
||||
as.configCache.UpdateConfiguration(serverIDStr, config.(model.Configuration))
|
||||
case AssistRulesJson:
|
||||
as.configCache.UpdateAssistRules(serverIDStr, *config.(*model.AssistRules))
|
||||
as.configCache.UpdateAssistRules(serverIDStr, config.(model.AssistRules))
|
||||
case EventJson:
|
||||
as.configCache.UpdateEvent(serverIDStr, *config.(*model.EventConfig))
|
||||
as.configCache.UpdateEvent(serverIDStr, config.(model.EventConfig))
|
||||
case EventRulesJson:
|
||||
as.configCache.UpdateEventRules(serverIDStr, *config.(*model.EventRules))
|
||||
as.configCache.UpdateEventRules(serverIDStr, config.(model.EventRules))
|
||||
case SettingsJson:
|
||||
as.configCache.UpdateSettings(serverIDStr, *config.(*model.ServerSettings))
|
||||
as.configCache.UpdateSettings(serverIDStr, config.(model.ServerSettings))
|
||||
}
|
||||
|
||||
logging.Debug("Successfully loaded and cached config for server ID: %s, file: %s", serverIDStr, configFile)
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// GetConfigs
|
||||
// Gets all configurations for a server, using cache when possible.
|
||||
func (as *ConfigService) GetConfigs(ctx *fiber.Ctx) (*model.Configurations, error) {
|
||||
serverID := ctx.Params("id")
|
||||
|
||||
@@ -296,82 +293,33 @@ func (as *ConfigService) GetConfigs(ctx *fiber.Ctx) (*model.Configurations, erro
|
||||
func (as *ConfigService) LoadConfigs(server *model.Server) (*model.Configurations, error) {
|
||||
serverIDStr := server.ID.String()
|
||||
logging.Info("Loading configs for server ID: %s at path: %s", serverIDStr, server.GetConfigPath())
|
||||
configs := &model.Configurations{}
|
||||
|
||||
// Load configuration
|
||||
if cached, ok := as.configCache.GetConfiguration(serverIDStr); ok {
|
||||
logging.Debug("Using cached configuration for server %s", serverIDStr)
|
||||
configs.Configuration = *cached
|
||||
} else {
|
||||
logging.Debug("Loading configuration from disk for server %s", serverIDStr)
|
||||
config, err := mustDecode[model.Configuration](ConfigurationJson, server.GetConfigPath())
|
||||
if err != nil {
|
||||
logging.Error("Failed to load configuration for server %s: %v", serverIDStr, err)
|
||||
return nil, fmt.Errorf("failed to load configuration: %v", err)
|
||||
}
|
||||
configs.Configuration = config
|
||||
as.configCache.UpdateConfiguration(serverIDStr, config)
|
||||
settingsConf, err := as.getConfigFile(server, SettingsJson)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load assist rules
|
||||
if cached, ok := as.configCache.GetAssistRules(serverIDStr); ok {
|
||||
logging.Debug("Using cached assist rules for server %s", serverIDStr)
|
||||
configs.AssistRules = *cached
|
||||
} else {
|
||||
logging.Debug("Loading assist rules from disk for server %s", serverIDStr)
|
||||
rules, err := mustDecode[model.AssistRules](AssistRulesJson, server.GetConfigPath())
|
||||
if err != nil {
|
||||
logging.Error("Failed to load assist rules for server %s: %v", serverIDStr, err)
|
||||
return nil, fmt.Errorf("failed to load assist rules: %v", err)
|
||||
}
|
||||
configs.AssistRules = rules
|
||||
as.configCache.UpdateAssistRules(serverIDStr, rules)
|
||||
eventRulesConf, err := as.getConfigFile(server, EventRulesJson)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load event config
|
||||
if cached, ok := as.configCache.GetEvent(serverIDStr); ok {
|
||||
logging.Debug("Using cached event config for server %s", serverIDStr)
|
||||
configs.Event = *cached
|
||||
} else {
|
||||
logging.Debug("Loading event config from disk for server %s", serverIDStr)
|
||||
event, err := mustDecode[model.EventConfig](EventJson, server.GetConfigPath())
|
||||
if err != nil {
|
||||
logging.Error("Failed to load event config for server %s: %v", serverIDStr, err)
|
||||
return nil, fmt.Errorf("failed to load event config: %v", err)
|
||||
}
|
||||
configs.Event = event
|
||||
logging.Debug("Updating event config for server %s with track: %s", serverIDStr, event.Track)
|
||||
as.configCache.UpdateEvent(serverIDStr, event)
|
||||
eventConf, err := as.getConfigFile(server, EventJson)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load event rules
|
||||
if cached, ok := as.configCache.GetEventRules(serverIDStr); ok {
|
||||
logging.Debug("Using cached event rules for server %s", serverIDStr)
|
||||
configs.EventRules = *cached
|
||||
} else {
|
||||
logging.Debug("Loading event rules from disk for server %s", serverIDStr)
|
||||
rules, err := mustDecode[model.EventRules](EventRulesJson, server.GetConfigPath())
|
||||
if err != nil {
|
||||
logging.Error("Failed to load event rules for server %s: %v", serverIDStr, err)
|
||||
return nil, fmt.Errorf("failed to load event rules: %v", err)
|
||||
}
|
||||
configs.EventRules = rules
|
||||
as.configCache.UpdateEventRules(serverIDStr, rules)
|
||||
assistRulesConf, err := as.getConfigFile(server, AssistRulesJson)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load settings
|
||||
if cached, ok := as.configCache.GetSettings(serverIDStr); ok {
|
||||
logging.Debug("Using cached settings for server %s", serverIDStr)
|
||||
configs.Settings = *cached
|
||||
} else {
|
||||
logging.Debug("Loading settings from disk for server %s", serverIDStr)
|
||||
settings, err := mustDecode[model.ServerSettings](SettingsJson, server.GetConfigPath())
|
||||
if err != nil {
|
||||
logging.Error("Failed to load settings for server %s: %v", serverIDStr, err)
|
||||
return nil, fmt.Errorf("failed to load settings: %v", err)
|
||||
}
|
||||
configs.Settings = settings
|
||||
as.configCache.UpdateSettings(serverIDStr, settings)
|
||||
configurationConf, err := as.getConfigFile(server, ConfigurationJson)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
configs := &model.Configurations{
|
||||
Settings: settingsConf.(model.ServerSettings),
|
||||
EventRules: eventRulesConf.(model.EventRules),
|
||||
Event: eventConf.(model.EventConfig),
|
||||
AssistRules: assistRulesConf.(model.AssistRules),
|
||||
Configuration: configurationConf.(model.Configuration),
|
||||
}
|
||||
|
||||
logging.Info("Successfully loaded all configs for server %s", serverIDStr)
|
||||
@@ -396,9 +344,6 @@ func readFile(path string, configFile string) ([]byte, error) {
|
||||
configPath := filepath.Join(path, configFile)
|
||||
oldData, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("config file %s does not exist at %s", configFile, configPath)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return oldData, nil
|
||||
@@ -475,9 +420,7 @@ func (as *ConfigService) GetConfiguration(server *model.Server) (*model.Configur
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// SaveConfiguration saves the configuration for a server
|
||||
func (as *ConfigService) SaveConfiguration(server *model.Server, config *model.Configuration) error {
|
||||
// Convert config to map for UpdateConfig
|
||||
configMap := make(map[string]interface{})
|
||||
configBytes, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
@@ -487,7 +430,6 @@ func (as *ConfigService) SaveConfiguration(server *model.Server, config *model.C
|
||||
return fmt.Errorf("failed to unmarshal configuration: %v", err)
|
||||
}
|
||||
|
||||
// Update the configuration using the internal method
|
||||
_, err = as.updateConfigInternal(context.Background(), server.ID.String(), ConfigurationJson, &configMap, true)
|
||||
_, _, err = as.updateConfigFiles(context.Background(), server, ConfigurationJson, &configMap, true)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -96,11 +96,9 @@ func (s *FirewallService) DeleteServerRules(serverName string, tcpPorts, udpPort
|
||||
}
|
||||
|
||||
func (s *FirewallService) UpdateServerRules(serverName string, tcpPorts, udpPorts []int) error {
|
||||
// First delete existing rules
|
||||
if err := s.DeleteServerRules(serverName, tcpPorts, udpPorts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then create new rules
|
||||
return s.CreateServerRules(serverName, tcpPorts, udpPorts)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,47 +12,57 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// CacheInvalidator interface for cache invalidation
|
||||
type CacheInvalidator interface {
|
||||
InvalidateUserPermissions(userID string)
|
||||
InvalidateAllUserPermissions()
|
||||
}
|
||||
|
||||
// MembershipService provides business logic for membership-related operations.
|
||||
type MembershipService struct {
|
||||
repo *repository.MembershipRepository
|
||||
cacheInvalidator CacheInvalidator
|
||||
jwtHandler *jwt.JWTHandler
|
||||
openJwtHandler *jwt.OpenJWTHandler
|
||||
}
|
||||
|
||||
// NewMembershipService creates a new MembershipService.
|
||||
func NewMembershipService(repo *repository.MembershipRepository) *MembershipService {
|
||||
func NewMembershipService(repo *repository.MembershipRepository, jwtHandler *jwt.JWTHandler, openJwtHandler *jwt.OpenJWTHandler) *MembershipService {
|
||||
return &MembershipService{
|
||||
repo: repo,
|
||||
cacheInvalidator: nil, // Will be set later via SetCacheInvalidator
|
||||
cacheInvalidator: nil,
|
||||
jwtHandler: jwtHandler,
|
||||
openJwtHandler: openJwtHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// SetCacheInvalidator sets the cache invalidator after service initialization
|
||||
func (s *MembershipService) SetCacheInvalidator(invalidator CacheInvalidator) {
|
||||
s.cacheInvalidator = invalidator
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns a JWT.
|
||||
func (s *MembershipService) Login(ctx context.Context, username, password string) (string, error) {
|
||||
func (s *MembershipService) HandleLogin(ctx context.Context, username, password string) (*model.User, error) {
|
||||
user, err := s.repo.FindUserByUsername(ctx, username)
|
||||
if err != nil {
|
||||
return "", errors.New("invalid credentials")
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
// Use secure password verification with constant-time comparison
|
||||
if err := user.VerifyPassword(password); err != nil {
|
||||
return "", errors.New("invalid credentials")
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
|
||||
return jwt.GenerateToken(user)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *MembershipService) Login(ctx context.Context, username, password string) (string, error) {
|
||||
user, err := s.HandleLogin(ctx, username, password)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return s.jwtHandler.GenerateToken(user.ID.String())
|
||||
}
|
||||
|
||||
func (s *MembershipService) GenerateOpenToken(ctx context.Context, userId string) (string, error) {
|
||||
return s.openJwtHandler.GenerateToken(userId)
|
||||
}
|
||||
|
||||
// CreateUser creates a new user.
|
||||
func (s *MembershipService) CreateUser(ctx context.Context, username, password, roleName string) (*model.User, error) {
|
||||
|
||||
role, err := s.repo.FindRoleByName(ctx, roleName)
|
||||
@@ -76,43 +86,35 @@ func (s *MembershipService) CreateUser(ctx context.Context, username, password,
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// ListUsers retrieves all users.
|
||||
func (s *MembershipService) ListUsers(ctx context.Context) ([]*model.User, error) {
|
||||
return s.repo.ListUsers(ctx)
|
||||
}
|
||||
|
||||
// GetUser retrieves a single user by ID.
|
||||
func (s *MembershipService) GetUser(ctx context.Context, userID uuid.UUID) (*model.User, error) {
|
||||
return s.repo.FindUserByID(ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserWithPermissions retrieves a single user by ID with their role and permissions.
|
||||
func (s *MembershipService) GetUserWithPermissions(ctx context.Context, userID string) (*model.User, error) {
|
||||
return s.repo.FindUserByIDWithPermissions(ctx, userID)
|
||||
}
|
||||
|
||||
// UpdateUserRequest defines the request body for updating a user.
|
||||
type UpdateUserRequest struct {
|
||||
Username *string `json:"username"`
|
||||
Password *string `json:"password"`
|
||||
RoleID *uuid.UUID `json:"roleId"`
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user with validation to prevent Super Admin deletion.
|
||||
func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) error {
|
||||
// Get user with role information
|
||||
user, err := s.repo.FindUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return errors.New("user not found")
|
||||
}
|
||||
|
||||
// Get role to check if it's Super Admin
|
||||
role, err := s.repo.FindRoleByID(ctx, user.RoleID)
|
||||
if err != nil {
|
||||
return errors.New("user role not found")
|
||||
}
|
||||
|
||||
// Prevent deletion of Super Admin users
|
||||
if role.Name == "Super Admin" {
|
||||
return errors.New("cannot delete Super Admin user")
|
||||
}
|
||||
@@ -122,7 +124,6 @@ func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) er
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidate cache for deleted user
|
||||
if s.cacheInvalidator != nil {
|
||||
s.cacheInvalidator.InvalidateUserPermissions(userID.String())
|
||||
}
|
||||
@@ -131,7 +132,6 @@ func (s *MembershipService) DeleteUser(ctx context.Context, userID uuid.UUID) er
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateUser updates a user's details.
|
||||
func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, req UpdateUserRequest) (*model.User, error) {
|
||||
user, err := s.repo.FindUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -143,12 +143,10 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re
|
||||
}
|
||||
|
||||
if req.Password != nil && *req.Password != "" {
|
||||
// Password will be automatically hashed in BeforeUpdate hook
|
||||
user.Password = *req.Password
|
||||
}
|
||||
|
||||
if req.RoleID != nil {
|
||||
// Check if role exists
|
||||
_, err := s.repo.FindRoleByID(ctx, *req.RoleID)
|
||||
if err != nil {
|
||||
return nil, errors.New("role not found")
|
||||
@@ -160,7 +158,6 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Invalidate cache if role was changed
|
||||
if req.RoleID != nil && s.cacheInvalidator != nil {
|
||||
s.cacheInvalidator.InvalidateUserPermissions(userID.String())
|
||||
}
|
||||
@@ -169,14 +166,12 @@ func (s *MembershipService) UpdateUser(ctx context.Context, userID uuid.UUID, re
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// HasPermission checks if a user has a specific permission.
|
||||
func (s *MembershipService) HasPermission(ctx context.Context, userID string, permissionName string) (bool, error) {
|
||||
user, err := s.repo.FindUserByIDWithPermissions(ctx, userID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Super admin and Admin have all permissions
|
||||
if user.Role.Name == "Super Admin" || user.Role.Name == "Admin" {
|
||||
return true, nil
|
||||
}
|
||||
@@ -190,15 +185,13 @@ func (s *MembershipService) HasPermission(ctx context.Context, userID string, pe
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// SetupInitialData creates the initial roles and permissions.
|
||||
func (s *MembershipService) SetupInitialData(ctx context.Context) error {
|
||||
// Define all permissions
|
||||
permissions := model.AllPermissions()
|
||||
|
||||
createdPermissions := make([]model.Permission, 0)
|
||||
for _, pName := range permissions {
|
||||
perm, err := s.repo.FindPermissionByName(ctx, pName)
|
||||
if err != nil { // Assuming error means not found
|
||||
if err != nil {
|
||||
perm = &model.Permission{Name: pName}
|
||||
if err := s.repo.CreatePermission(ctx, perm); err != nil {
|
||||
return err
|
||||
@@ -207,7 +200,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
|
||||
createdPermissions = append(createdPermissions, *perm)
|
||||
}
|
||||
|
||||
// Create Super Admin role with all permissions
|
||||
superAdminRole, err := s.repo.FindRoleByName(ctx, "Super Admin")
|
||||
if err != nil {
|
||||
superAdminRole = &model.Role{Name: "Super Admin"}
|
||||
@@ -219,7 +211,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create Admin role with same permissions as Super Admin
|
||||
adminRole, err := s.repo.FindRoleByName(ctx, "Admin")
|
||||
if err != nil {
|
||||
adminRole = &model.Role{Name: "Admin"}
|
||||
@@ -231,7 +222,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create Manager role with limited permissions (excluding membership, role, user, server create/delete)
|
||||
managerRole, err := s.repo.FindRoleByName(ctx, "Manager")
|
||||
if err != nil {
|
||||
managerRole = &model.Role{Name: "Manager"}
|
||||
@@ -240,7 +230,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Define manager permissions (limited set)
|
||||
managerPermissionNames := []string{
|
||||
model.ServerView,
|
||||
model.ServerUpdate,
|
||||
@@ -264,16 +253,14 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidate all caches after role setup changes
|
||||
if s.cacheInvalidator != nil {
|
||||
s.cacheInvalidator.InvalidateAllUserPermissions()
|
||||
}
|
||||
|
||||
// Create a default admin user if one doesn't exist
|
||||
_, err = s.repo.FindUserByUsername(ctx, "admin")
|
||||
if err != nil {
|
||||
logging.Debug("Creating default admin user")
|
||||
_, err = s.CreateUser(ctx, "admin", os.Getenv("PASSWORD"), "Super Admin") // Default password, should be changed
|
||||
_, err = s.CreateUser(ctx, "admin", os.Getenv("PASSWORD"), "Super Admin")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -282,7 +269,6 @@ func (s *MembershipService) SetupInitialData(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllRoles retrieves all roles for dropdown selection.
|
||||
func (s *MembershipService) GetAllRoles(ctx context.Context) ([]*model.Role, error) {
|
||||
return s.repo.ListRoles(ctx)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
|
||||
const (
|
||||
DefaultStartPort = 9600
|
||||
RequiredPortCount = 1 // Update this if ACC needs more ports
|
||||
RequiredPortCount = 1
|
||||
)
|
||||
|
||||
type ServerService struct {
|
||||
@@ -31,6 +31,7 @@ type ServerService struct {
|
||||
steamService *SteamService
|
||||
windowsService *WindowsService
|
||||
firewallService *FirewallService
|
||||
webSocketService *WebSocketService
|
||||
instances sync.Map // Track instances per server
|
||||
lastInsertTimes sync.Map // Track last insert time per server
|
||||
debouncers sync.Map // Track debounce timers per server
|
||||
@@ -44,18 +45,15 @@ type pendingState struct {
|
||||
}
|
||||
|
||||
func (s *ServerService) ensureLogTailing(server *model.Server, instance *tracking.AccServerInstance) {
|
||||
// Check if we already have a tailer
|
||||
if _, exists := s.logTailers.Load(server.ID); exists {
|
||||
return
|
||||
}
|
||||
|
||||
// Start tailing in a goroutine that handles file creation/deletion
|
||||
go func() {
|
||||
logPath := filepath.Join(server.GetLogPath(), "server.log")
|
||||
tailer := tracking.NewLogTailer(logPath, instance.HandleLogLine)
|
||||
s.logTailers.Store(server.ID, tailer)
|
||||
|
||||
// Start tailing and automatically handle file changes
|
||||
tailer.Start()
|
||||
}()
|
||||
}
|
||||
@@ -68,6 +66,7 @@ func NewServerService(
|
||||
steamService *SteamService,
|
||||
windowsService *WindowsService,
|
||||
firewallService *FirewallService,
|
||||
webSocketService *WebSocketService,
|
||||
) *ServerService {
|
||||
service := &ServerService{
|
||||
repository: repository,
|
||||
@@ -77,9 +76,9 @@ func NewServerService(
|
||||
steamService: steamService,
|
||||
windowsService: windowsService,
|
||||
firewallService: firewallService,
|
||||
webSocketService: webSocketService,
|
||||
}
|
||||
|
||||
// Initialize server instances
|
||||
servers, err := repository.GetAll(context.Background(), &model.ServerFilter{})
|
||||
if err != nil {
|
||||
logging.Error("Failed to get servers: %v", err)
|
||||
@@ -87,7 +86,6 @@ func NewServerService(
|
||||
}
|
||||
|
||||
for i := range *servers {
|
||||
// Initialize instance regardless of status
|
||||
logging.Info("Starting server runtime for server ID: %d", (*servers)[i].ID)
|
||||
service.StartAccServerRuntime(&(*servers)[i])
|
||||
}
|
||||
@@ -96,7 +94,7 @@ func NewServerService(
|
||||
}
|
||||
|
||||
func (s *ServerService) shouldInsertStateHistory(serverID uuid.UUID) bool {
|
||||
insertInterval := 5 * time.Minute // Configure this as needed
|
||||
insertInterval := 5 * time.Minute
|
||||
|
||||
lastInsertInterface, exists := s.lastInsertTimes.Load(serverID)
|
||||
if !exists {
|
||||
@@ -119,16 +117,15 @@ func (s *ServerService) getNextSessionID(serverID uuid.UUID) uuid.UUID {
|
||||
lastID, err := s.stateHistoryRepo.GetLastSessionID(context.Background(), serverID)
|
||||
if err != nil {
|
||||
logging.Error("Failed to get last session ID for server %s: %v", serverID, err)
|
||||
return uuid.New() // Return new UUID as fallback
|
||||
return uuid.New()
|
||||
}
|
||||
if lastID == uuid.Nil {
|
||||
return uuid.New() // Return new UUID if no previous session
|
||||
return uuid.New()
|
||||
}
|
||||
return uuid.New() // Always generate new UUID for each session
|
||||
return uuid.New()
|
||||
}
|
||||
|
||||
func (s *ServerService) insertStateHistory(serverID uuid.UUID, state *model.ServerState) {
|
||||
// Get or create session ID when session changes
|
||||
currentSessionInterface, exists := s.instances.Load(serverID)
|
||||
var sessionID uuid.UUID
|
||||
if !exists {
|
||||
@@ -159,8 +156,7 @@ func (s *ServerService) insertStateHistory(serverID uuid.UUID, state *model.Serv
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ServerService) updateSessionDuration(server *model.Server, sessionType string) {
|
||||
// Get configs using helper methods
|
||||
func (s *ServerService) updateSessionDuration(server *model.Server, sessionType model.TrackSession) {
|
||||
event, err := s.configService.GetEventConfig(server)
|
||||
if err != nil {
|
||||
event = &model.EventConfig{}
|
||||
@@ -178,9 +174,7 @@ func (s *ServerService) updateSessionDuration(server *model.Server, sessionType
|
||||
serverInstance.State.Track = event.Track
|
||||
serverInstance.State.MaxConnections = configuration.MaxConnections.ToInt()
|
||||
|
||||
// Check if session type has changed
|
||||
if serverInstance.State.Session != sessionType {
|
||||
// Get new session ID for the new session
|
||||
sessionID := s.getNextSessionID(server.ID)
|
||||
s.sessionIDs.Store(server.ID, sessionID)
|
||||
}
|
||||
@@ -201,32 +195,27 @@ func (s *ServerService) updateSessionDuration(server *model.Server, sessionType
|
||||
}
|
||||
|
||||
func (s *ServerService) GenerateServerPath(server *model.Server) {
|
||||
// Get the base steamcmd path from environment variable
|
||||
steamCMDPath := env.GetSteamCMDDirPath()
|
||||
server.Path = server.GenerateServerPath(steamCMDPath)
|
||||
server.FromSteamCMD = true
|
||||
}
|
||||
|
||||
func (s *ServerService) handleStateChange(server *model.Server, state *model.ServerState) {
|
||||
// Update session duration when session changes
|
||||
s.updateSessionDuration(server, state.Session)
|
||||
|
||||
// Invalidate status cache when server state changes
|
||||
s.apiService.statusCache.InvalidateStatus(server.ServiceName)
|
||||
|
||||
// Cancel existing timer if any
|
||||
if debouncer, exists := s.debouncers.Load(server.ID); exists {
|
||||
pending := debouncer.(*pendingState)
|
||||
pending.timer.Stop()
|
||||
}
|
||||
|
||||
// Create new timer
|
||||
timer := time.NewTimer(5 * time.Minute)
|
||||
s.debouncers.Store(server.ID, &pendingState{
|
||||
timer: timer,
|
||||
state: state,
|
||||
})
|
||||
|
||||
// Start goroutine to handle the delayed insert
|
||||
go func() {
|
||||
<-timer.C
|
||||
if debouncer, exists := s.debouncers.Load(server.ID); exists {
|
||||
@@ -236,14 +225,12 @@ func (s *ServerService) handleStateChange(server *model.Server, state *model.Ser
|
||||
}
|
||||
}()
|
||||
|
||||
// If enough time has passed since last insert, insert immediately
|
||||
if s.shouldInsertStateHistory(server.ID) {
|
||||
s.insertStateHistory(server.ID, state)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerService) StartAccServerRuntime(server *model.Server) {
|
||||
// Get or create instance
|
||||
instanceInterface, exists := s.instances.Load(server.ID)
|
||||
var instance *tracking.AccServerInstance
|
||||
if !exists {
|
||||
@@ -255,20 +242,14 @@ func (s *ServerService) StartAccServerRuntime(server *model.Server) {
|
||||
instance = instanceInterface.(*tracking.AccServerInstance)
|
||||
}
|
||||
|
||||
// Invalidate config cache for this server before loading new configs
|
||||
serverIDStr := server.ID.String()
|
||||
s.configService.configCache.InvalidateServerCache(serverIDStr)
|
||||
|
||||
s.updateSessionDuration(server, instance.State.Session)
|
||||
|
||||
// Ensure log tailing is running (regardless of server status)
|
||||
s.ensureLogTailing(server, instance)
|
||||
}
|
||||
|
||||
// GetAll
|
||||
// Gets All rows from Server table.
|
||||
//
|
||||
// Args:
|
||||
// context.Context: Application context
|
||||
// Returns:
|
||||
// string: Application version
|
||||
@@ -300,10 +281,6 @@ func (s *ServerService) GetAll(ctx *fiber.Ctx, filter *model.ServerFilter) (*[]m
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// GetById
|
||||
// Gets rows by ID from Server table.
|
||||
//
|
||||
// Args:
|
||||
// context.Context: Application context
|
||||
// Returns:
|
||||
// string: Application version
|
||||
@@ -330,76 +307,188 @@ func (as *ServerService) GetById(ctx *fiber.Ctx, serverID uuid.UUID) (*model.Ser
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (s *ServerService) CreateServer(ctx *fiber.Ctx, server *model.Server) error {
|
||||
// Validate basic server configuration
|
||||
func (s *ServerService) CreateServerAsync(ctx *fiber.Ctx, server *model.Server) error {
|
||||
logging.Info("create server start")
|
||||
if err := server.Validate(); err != nil {
|
||||
logging.Info("create server validation failed")
|
||||
return err
|
||||
}
|
||||
|
||||
// Install server using SteamCMD
|
||||
if err := s.steamService.InstallServer(ctx.UserContext(), server.GetServerPath()); err != nil {
|
||||
return fmt.Errorf("failed to install server: %v", err)
|
||||
}
|
||||
s.GenerateServerPath(server)
|
||||
|
||||
// Create Windows service with correct paths
|
||||
execPath := filepath.Join(server.GetServerPath(), "accServer.exe")
|
||||
serverWorkingDir := filepath.Join(server.GetServerPath(), "server")
|
||||
if err := s.windowsService.CreateService(ctx.UserContext(), server.ServiceName, execPath, serverWorkingDir, nil); err != nil {
|
||||
// Cleanup on failure
|
||||
s.steamService.UninstallServer(server.Path)
|
||||
return fmt.Errorf("failed to create Windows service: %v", err)
|
||||
}
|
||||
bgCtx := context.Background()
|
||||
|
||||
s.configureFirewall(server)
|
||||
ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find available ports: %v", err)
|
||||
}
|
||||
|
||||
// Use the first port for both TCP and UDP
|
||||
serverPort := ports[0]
|
||||
tcpPorts := []int{serverPort}
|
||||
udpPorts := []int{serverPort}
|
||||
if err := s.firewallService.CreateServerRules(server.ServiceName, tcpPorts, udpPorts); err != nil {
|
||||
// Cleanup on failure
|
||||
s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName)
|
||||
s.steamService.UninstallServer(server.Path)
|
||||
return fmt.Errorf("failed to create firewall rules: %v", err)
|
||||
}
|
||||
|
||||
// Update server configuration with the allocated port
|
||||
if err := s.updateServerPort(server, serverPort); err != nil {
|
||||
return fmt.Errorf("failed to update server configuration: %v", err)
|
||||
}
|
||||
|
||||
// Insert server into database
|
||||
if err := s.repository.Insert(ctx.UserContext(), server); err != nil {
|
||||
// Cleanup on failure
|
||||
s.firewallService.DeleteServerRules(server.ServiceName, tcpPorts, udpPorts)
|
||||
s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName)
|
||||
s.steamService.UninstallServer(server.Path)
|
||||
return fmt.Errorf("failed to insert server into database: %v", err)
|
||||
}
|
||||
|
||||
// Initialize server runtime
|
||||
s.StartAccServerRuntime(server)
|
||||
go func() {
|
||||
logging.Info("create server start background")
|
||||
if err := s.createServerBackground(bgCtx, server); err != nil {
|
||||
logging.Error("Async server creation failed for server %s: %v", server.ID, err)
|
||||
s.webSocketService.BroadcastError(server.ID, "Server creation failed", err.Error())
|
||||
s.webSocketService.BroadcastComplete(server.ID, false, fmt.Sprintf("Server creation failed: %v", err))
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type createServerStep struct {
|
||||
stepType model.ServerCreationStep
|
||||
important bool
|
||||
callback func() (string, error)
|
||||
description string
|
||||
}
|
||||
|
||||
func (s *ServerService) createServerBackground(ctx context.Context, server *model.Server) error {
|
||||
var serverPort int
|
||||
var tcpPorts, udpPorts []int
|
||||
|
||||
steps := []createServerStep{
|
||||
{
|
||||
stepType: model.StepValidation,
|
||||
important: true,
|
||||
description: "Server configuration validated successfully",
|
||||
callback: func() (string, error) {
|
||||
if err := server.Validate(); err != nil {
|
||||
return "", fmt.Errorf("validation failed: %v", err)
|
||||
}
|
||||
return "Server configuration validated successfully", nil
|
||||
},
|
||||
},
|
||||
{
|
||||
stepType: model.StepDirectoryCreation,
|
||||
important: true,
|
||||
description: "Server directories prepared",
|
||||
callback: func() (string, error) {
|
||||
return "Server directories prepared", nil
|
||||
},
|
||||
},
|
||||
{
|
||||
stepType: model.StepSteamDownload,
|
||||
important: true,
|
||||
description: "Server files downloaded successfully",
|
||||
callback: func() (string, error) {
|
||||
if err := s.steamService.InstallServerWithWebSocket(ctx, server.Path, &server.ID, s.webSocketService); err != nil {
|
||||
return "", fmt.Errorf("failed to install server: %v", err)
|
||||
}
|
||||
return "Server files downloaded successfully", nil
|
||||
},
|
||||
},
|
||||
{
|
||||
stepType: model.StepConfigGeneration,
|
||||
important: true,
|
||||
description: "",
|
||||
callback: func() (string, error) {
|
||||
ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to find available ports: %v", err)
|
||||
}
|
||||
|
||||
serverPort = ports[0]
|
||||
|
||||
if err := s.updateServerPort(server, serverPort); err != nil {
|
||||
return "", fmt.Errorf("failed to update server configuration: %v", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Server configuration generated (Port: %d)", serverPort), nil
|
||||
},
|
||||
},
|
||||
{
|
||||
stepType: model.StepServiceCreation,
|
||||
important: true,
|
||||
description: "",
|
||||
callback: func() (string, error) {
|
||||
execPath := filepath.Join(server.GetServerPath(), "accServer.exe")
|
||||
serverWorkingDir := filepath.Join(server.GetServerPath(), "server")
|
||||
if err := s.windowsService.CreateService(ctx, server.ServiceName, execPath, serverWorkingDir, nil); err != nil {
|
||||
return "", fmt.Errorf("failed to create Windows service: %v", err)
|
||||
}
|
||||
return fmt.Sprintf("Windows service '%s' created successfully", server.ServiceName), nil
|
||||
},
|
||||
},
|
||||
{
|
||||
stepType: model.StepFirewallRules,
|
||||
important: false,
|
||||
description: "",
|
||||
callback: func() (string, error) {
|
||||
s.configureFirewall(server)
|
||||
tcpPorts = []int{serverPort}
|
||||
udpPorts = []int{serverPort}
|
||||
if err := s.firewallService.CreateServerRules(server.ServiceName, tcpPorts, udpPorts); err != nil {
|
||||
return "", fmt.Errorf("failed to create firewall rules: %v", err)
|
||||
}
|
||||
return fmt.Sprintf("Firewall rules created for port %d", serverPort), nil
|
||||
},
|
||||
},
|
||||
{
|
||||
stepType: model.StepDatabaseSave,
|
||||
important: true,
|
||||
description: "Server saved to database successfully",
|
||||
callback: func() (string, error) {
|
||||
if err := s.repository.Insert(ctx, server); err != nil {
|
||||
return "", fmt.Errorf("failed to insert server into database: %v", err)
|
||||
}
|
||||
return "Server saved to database successfully", nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, step := range steps {
|
||||
s.webSocketService.BroadcastStep(server.ID, step.stepType, model.StatusInProgress,
|
||||
model.GetStepDescription(step.stepType), "")
|
||||
|
||||
successMessage, err := step.callback()
|
||||
if err != nil {
|
||||
s.webSocketService.BroadcastStep(server.ID, step.stepType, model.StatusFailed,
|
||||
"", err.Error())
|
||||
|
||||
if step.important {
|
||||
s.rollbackSteps(ctx, server, steps[:i], tcpPorts, udpPorts)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.webSocketService.BroadcastStep(server.ID, step.stepType, model.StatusCompleted,
|
||||
successMessage, "")
|
||||
}
|
||||
|
||||
s.StartAccServerRuntime(server)
|
||||
|
||||
s.webSocketService.BroadcastStep(server.ID, model.StepCompleted, model.StatusCompleted,
|
||||
model.GetStepDescription(model.StepCompleted), "")
|
||||
|
||||
s.webSocketService.BroadcastComplete(server.ID, true,
|
||||
fmt.Sprintf("Server '%s' created successfully on port %d", server.Name, serverPort))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerService) rollbackSteps(ctx context.Context, server *model.Server, completedSteps []createServerStep, tcpPorts, udpPorts []int) {
|
||||
for i := len(completedSteps) - 1; i >= 0; i-- {
|
||||
step := completedSteps[i]
|
||||
switch step.stepType {
|
||||
case model.StepDatabaseSave:
|
||||
s.repository.Delete(ctx, server.ID)
|
||||
case model.StepFirewallRules:
|
||||
if len(tcpPorts) > 0 && len(udpPorts) > 0 {
|
||||
s.firewallService.DeleteServerRules(server.ServiceName, tcpPorts, udpPorts)
|
||||
}
|
||||
case model.StepServiceCreation:
|
||||
s.windowsService.DeleteService(ctx, server.ServiceName)
|
||||
case model.StepSteamDownload:
|
||||
s.steamService.UninstallServer(server.Path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error {
|
||||
// Get server details
|
||||
server, err := s.repository.GetByID(ctx.UserContext(), serverID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get server details: %v", err)
|
||||
}
|
||||
|
||||
// Stop and remove Windows service
|
||||
if err := s.windowsService.DeleteService(ctx.UserContext(), server.ServiceName); err != nil {
|
||||
logging.Error("Failed to delete Windows service: %v", err)
|
||||
}
|
||||
|
||||
// Remove firewall rules
|
||||
configuration, err := s.configService.GetConfiguration(server)
|
||||
if err != nil {
|
||||
logging.Error("Failed to get configuration for server %d: %v", server.ID, err)
|
||||
@@ -410,17 +499,14 @@ func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error {
|
||||
logging.Error("Failed to delete firewall rules: %v", err)
|
||||
}
|
||||
|
||||
// Uninstall server files
|
||||
if err := s.steamService.UninstallServer(server.Path); err != nil {
|
||||
logging.Error("Failed to uninstall server: %v", err)
|
||||
}
|
||||
|
||||
// Remove from database
|
||||
if err := s.repository.Delete(ctx.UserContext(), serverID); err != nil {
|
||||
return fmt.Errorf("failed to delete server from database: %v", err)
|
||||
}
|
||||
|
||||
// Cleanup runtime resources
|
||||
if tailer, exists := s.logTailers.Load(server.ID); exists {
|
||||
tailer.(*tracking.LogTailer).Stop()
|
||||
s.logTailers.Delete(server.ID)
|
||||
@@ -430,84 +516,27 @@ func (s *ServerService) DeleteServer(ctx *fiber.Ctx, serverID uuid.UUID) error {
|
||||
s.debouncers.Delete(server.ID)
|
||||
s.sessionIDs.Delete(server.ID)
|
||||
|
||||
// Invalidate status cache for deleted server
|
||||
s.apiService.statusCache.InvalidateStatus(server.ServiceName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerService) UpdateServer(ctx *fiber.Ctx, server *model.Server) error {
|
||||
// Validate server configuration
|
||||
if err := server.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get existing server details
|
||||
existingServer, err := s.repository.GetByID(ctx.UserContext(), server.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing server details: %v", err)
|
||||
}
|
||||
|
||||
// Update server files if path changed
|
||||
if existingServer.Path != server.Path {
|
||||
if err := s.steamService.InstallServer(ctx.UserContext(), server.Path); err != nil {
|
||||
return fmt.Errorf("failed to install server to new location: %v", err)
|
||||
}
|
||||
// Clean up old installation
|
||||
if err := s.steamService.UninstallServer(existingServer.Path); err != nil {
|
||||
logging.Error("Failed to remove old server installation: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update Windows service if necessary
|
||||
if existingServer.ServiceName != server.ServiceName || existingServer.Path != server.Path {
|
||||
execPath := filepath.Join(server.GetServerPath(), "accServer.exe")
|
||||
serverWorkingDir := server.GetServerPath()
|
||||
if err := s.windowsService.UpdateService(ctx.UserContext(), server.ServiceName, execPath, serverWorkingDir, nil); err != nil {
|
||||
return fmt.Errorf("failed to update Windows service: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update firewall rules if service name changed
|
||||
if existingServer.ServiceName != server.ServiceName {
|
||||
if err := s.configureFirewall(server); err != nil {
|
||||
return fmt.Errorf("failed to update firewall rules: %v", err)
|
||||
}
|
||||
// Invalidate cache for old service name
|
||||
s.apiService.statusCache.InvalidateStatus(existingServer.ServiceName)
|
||||
}
|
||||
|
||||
// Update database record
|
||||
if err := s.repository.Update(ctx.UserContext(), server); err != nil {
|
||||
return fmt.Errorf("failed to update server in database: %v", err)
|
||||
}
|
||||
|
||||
// Restart server runtime
|
||||
s.StartAccServerRuntime(server)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerService) configureFirewall(server *model.Server) error {
|
||||
// Find available ports for the server
|
||||
ports, err := network.FindAvailablePortRange(DefaultStartPort, RequiredPortCount)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find available ports: %v", err)
|
||||
}
|
||||
|
||||
// Use the first port for both TCP and UDP
|
||||
serverPort := ports[0]
|
||||
tcpPorts := []int{serverPort}
|
||||
udpPorts := []int{serverPort}
|
||||
|
||||
logging.Info("Configuring firewall for server %d with port %d", server.ID, serverPort)
|
||||
|
||||
// Configure firewall rules
|
||||
if err := s.firewallService.UpdateServerRules(server.Name, tcpPorts, udpPorts); err != nil {
|
||||
return fmt.Errorf("failed to configure firewall: %v", err)
|
||||
}
|
||||
|
||||
// Update server configuration with the allocated port
|
||||
if err := s.updateServerPort(server, serverPort); err != nil {
|
||||
return fmt.Errorf("failed to update server configuration: %v", err)
|
||||
}
|
||||
@@ -516,7 +545,6 @@ func (s *ServerService) configureFirewall(server *model.Server) error {
|
||||
}
|
||||
|
||||
func (s *ServerService) updateServerPort(server *model.Server, port int) error {
|
||||
// Load current configuration
|
||||
config, err := s.configService.GetConfiguration(server)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load server configuration: %v", err)
|
||||
@@ -525,7 +553,6 @@ func (s *ServerService) updateServerPort(server *model.Server, port int) error {
|
||||
config.TcpPort = model.IntString(port)
|
||||
config.UdpPort = model.IntString(port)
|
||||
|
||||
// Save the updated configuration
|
||||
if err := s.configService.SaveConfiguration(server, config); err != nil {
|
||||
return fmt.Errorf("failed to save server configuration: %v", err)
|
||||
}
|
||||
|
||||
@@ -7,26 +7,22 @@ import (
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
// InitializeServices
|
||||
// Initializes Dependency Injection modules for services
|
||||
//
|
||||
// Args:
|
||||
// *dig.Container: Dig Container
|
||||
// *dig.Container: Dig Container
|
||||
func InitializeServices(c *dig.Container) {
|
||||
logging.Debug("Initializing repositories")
|
||||
repository.InitializeRepositories(c)
|
||||
|
||||
logging.Debug("Registering services")
|
||||
// Provide services
|
||||
c.Provide(NewSteamService)
|
||||
c.Provide(NewServerService)
|
||||
c.Provide(NewStateHistoryService)
|
||||
c.Provide(NewServiceControlService)
|
||||
c.Provide(NewConfigService)
|
||||
c.Provide(NewLookupService)
|
||||
c.Provide(NewSteamService)
|
||||
c.Provide(NewWindowsService)
|
||||
c.Provide(NewFirewallService)
|
||||
c.Provide(NewMembershipService)
|
||||
c.Provide(NewWebSocketService)
|
||||
|
||||
logging.Debug("Initializing service dependencies")
|
||||
err := c.Invoke(func(server *ServerService, api *ServiceControlService, config *ConfigService) {
|
||||
|
||||
@@ -24,9 +24,9 @@ func NewServiceControlService(repository *repository.ServiceControlRepository,
|
||||
repository: repository,
|
||||
serverRepository: serverRepository,
|
||||
statusCache: model.NewServerStatusCache(model.CacheConfig{
|
||||
ExpirationTime: 30 * time.Second, // Cache expires after 30 seconds
|
||||
ThrottleTime: 5 * time.Second, // Minimum 5 seconds between checks
|
||||
DefaultStatus: model.StatusRunning, // Default to running if throttled
|
||||
ExpirationTime: 30 * time.Second,
|
||||
ThrottleTime: 5 * time.Second,
|
||||
DefaultStatus: model.StatusRunning,
|
||||
}),
|
||||
windowsService: NewWindowsService(),
|
||||
}
|
||||
@@ -42,18 +42,15 @@ func (as *ServiceControlService) GetStatus(ctx *fiber.Ctx) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Try to get status from cache
|
||||
if status, shouldCheck := as.statusCache.GetStatus(serviceName); !shouldCheck {
|
||||
return status.String(), nil
|
||||
}
|
||||
|
||||
// If cache miss or expired, check actual status
|
||||
statusStr, err := as.StatusServer(serviceName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Parse and update cache with new status
|
||||
status := model.ParseServiceStatus(statusStr)
|
||||
as.statusCache.UpdateStatus(serviceName, status)
|
||||
return status.String(), nil
|
||||
@@ -65,7 +62,6 @@ func (as *ServiceControlService) ServiceControlStartServer(ctx *fiber.Ctx) (stri
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Update status cache for this service before starting
|
||||
as.statusCache.UpdateStatus(serviceName, model.StatusStarting)
|
||||
|
||||
_, err = as.StartServer(serviceName)
|
||||
@@ -77,7 +73,6 @@ func (as *ServiceControlService) ServiceControlStartServer(ctx *fiber.Ctx) (stri
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Parse and update cache with new status
|
||||
status := model.ParseServiceStatus(statusStr)
|
||||
as.statusCache.UpdateStatus(serviceName, status)
|
||||
return status.String(), nil
|
||||
@@ -89,7 +84,6 @@ func (as *ServiceControlService) ServiceControlStopServer(ctx *fiber.Ctx) (strin
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Update status cache for this service before stopping
|
||||
as.statusCache.UpdateStatus(serviceName, model.StatusStopping)
|
||||
|
||||
_, err = as.StopServer(serviceName)
|
||||
@@ -101,7 +95,6 @@ func (as *ServiceControlService) ServiceControlStopServer(ctx *fiber.Ctx) (strin
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Parse and update cache with new status
|
||||
status := model.ParseServiceStatus(statusStr)
|
||||
as.statusCache.UpdateStatus(serviceName, status)
|
||||
return status.String(), nil
|
||||
@@ -113,7 +106,6 @@ func (as *ServiceControlService) ServiceControlRestartServer(ctx *fiber.Ctx) (st
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Update status cache for this service before restarting
|
||||
as.statusCache.UpdateStatus(serviceName, model.StatusRestarting)
|
||||
|
||||
_, err = as.RestartServer(serviceName)
|
||||
@@ -125,7 +117,6 @@ func (as *ServiceControlService) ServiceControlRestartServer(ctx *fiber.Ctx) (st
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Parse and update cache with new status
|
||||
status := model.ParseServiceStatus(statusStr)
|
||||
as.statusCache.UpdateStatus(serviceName, status)
|
||||
return status.String(), nil
|
||||
@@ -135,20 +126,16 @@ func (as *ServiceControlService) StatusServer(serviceName string) (string, error
|
||||
return as.windowsService.Status(context.Background(), serviceName)
|
||||
}
|
||||
|
||||
// GetCachedStatus gets the cached status for a service name without requiring fiber context
|
||||
func (as *ServiceControlService) GetCachedStatus(serviceName string) (string, error) {
|
||||
// Try to get status from cache
|
||||
if status, shouldCheck := as.statusCache.GetStatus(serviceName); !shouldCheck {
|
||||
return status.String(), nil
|
||||
}
|
||||
|
||||
// If cache miss or expired, check actual status
|
||||
statusStr, err := as.StatusServer(serviceName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Parse and update cache with new status
|
||||
status := model.ParseServiceStatus(statusStr)
|
||||
as.statusCache.UpdateStatus(serviceName, status)
|
||||
return status.String(), nil
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
type ServiceManager struct {
|
||||
executor *command.CommandExecutor
|
||||
executor *command.CommandExecutor
|
||||
psExecutor *command.CommandExecutor
|
||||
}
|
||||
|
||||
@@ -24,17 +24,14 @@ func NewServiceManager() *ServiceManager {
|
||||
}
|
||||
|
||||
func (s *ServiceManager) ManageService(serviceName, action string) (string, error) {
|
||||
// Run NSSM command through PowerShell to ensure elevation
|
||||
output, err := s.psExecutor.ExecuteWithOutput("-nologo", "-noprofile", ".\\nssm", action, serviceName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Clean up output by removing null bytes and trimming whitespace
|
||||
cleaned := strings.TrimSpace(strings.ReplaceAll(output, "\x00", ""))
|
||||
// Remove \r\n from status strings
|
||||
cleaned = strings.TrimSuffix(cleaned, "\r\n")
|
||||
|
||||
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
@@ -51,11 +48,9 @@ func (s *ServiceManager) Stop(serviceName string) (string, error) {
|
||||
}
|
||||
|
||||
func (s *ServiceManager) Restart(serviceName string) (string, error) {
|
||||
// First stop the service
|
||||
if _, err := s.Stop(serviceName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Then start it again
|
||||
return s.Start(serviceName)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,7 +46,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
|
||||
|
||||
eg, gCtx := errgroup.WithContext(ctx.UserContext())
|
||||
|
||||
// Get Summary Stats (Peak/Avg Players, Total Sessions)
|
||||
eg.Go(func() error {
|
||||
summary, err := s.repository.GetSummaryStats(gCtx, filter)
|
||||
if err != nil {
|
||||
@@ -61,7 +60,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
|
||||
return nil
|
||||
})
|
||||
|
||||
// Get Total Playtime
|
||||
eg.Go(func() error {
|
||||
playtime, err := s.repository.GetTotalPlaytime(gCtx, filter)
|
||||
if err != nil {
|
||||
@@ -74,7 +72,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
|
||||
return nil
|
||||
})
|
||||
|
||||
// Get Player Count Over Time
|
||||
eg.Go(func() error {
|
||||
playerCount, err := s.repository.GetPlayerCountOverTime(gCtx, filter)
|
||||
if err != nil {
|
||||
@@ -87,7 +84,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
|
||||
return nil
|
||||
})
|
||||
|
||||
// Get Session Types
|
||||
eg.Go(func() error {
|
||||
sessionTypes, err := s.repository.GetSessionTypes(gCtx, filter)
|
||||
if err != nil {
|
||||
@@ -100,7 +96,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
|
||||
return nil
|
||||
})
|
||||
|
||||
// Get Daily Activity
|
||||
eg.Go(func() error {
|
||||
dailyActivity, err := s.repository.GetDailyActivity(gCtx, filter)
|
||||
if err != nil {
|
||||
@@ -113,7 +108,6 @@ func (s *StateHistoryService) GetStatistics(ctx *fiber.Ctx, filter *model.StateH
|
||||
return nil
|
||||
})
|
||||
|
||||
// Get Recent Sessions
|
||||
eg.Go(func() error {
|
||||
recentSessions, err := s.repository.GetRecentSessions(gCtx, filter)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,10 +6,15 @@ import (
|
||||
"acc-server-manager/local/utl/command"
|
||||
"acc-server-manager/local/utl/env"
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"acc-server-manager/local/utl/security"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -17,17 +22,25 @@ const (
|
||||
)
|
||||
|
||||
type SteamService struct {
|
||||
executor *command.CommandExecutor
|
||||
repository *repository.SteamCredentialsRepository
|
||||
executor *command.CommandExecutor
|
||||
repository *repository.SteamCredentialsRepository
|
||||
tfaManager *model.Steam2FAManager
|
||||
pathValidator *security.PathValidator
|
||||
downloadVerifier *security.DownloadVerifier
|
||||
}
|
||||
|
||||
func NewSteamService(repository *repository.SteamCredentialsRepository) *SteamService {
|
||||
func NewSteamService(repository *repository.SteamCredentialsRepository, tfaManager *model.Steam2FAManager) *SteamService {
|
||||
baseExecutor := &command.CommandExecutor{
|
||||
ExePath: "powershell",
|
||||
LogOutput: true,
|
||||
}
|
||||
|
||||
return &SteamService{
|
||||
executor: &command.CommandExecutor{
|
||||
ExePath: "powershell",
|
||||
LogOutput: true,
|
||||
},
|
||||
repository: repository,
|
||||
executor: baseExecutor,
|
||||
repository: repository,
|
||||
tfaManager: tfaManager,
|
||||
pathValidator: security.NewPathValidator(),
|
||||
downloadVerifier: security.NewDownloadVerifier(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,114 +55,272 @@ func (s *SteamService) SaveCredentials(ctx context.Context, creds *model.SteamCr
|
||||
return s.repository.Save(ctx, creds)
|
||||
}
|
||||
|
||||
func (s *SteamService) ensureSteamCMD(ctx context.Context) error {
|
||||
// Get SteamCMD path from environment variable
|
||||
func (s *SteamService) ensureSteamCMD(_ context.Context) error {
|
||||
steamCMDPath := env.GetSteamCMDPath()
|
||||
steamCMDDir := filepath.Dir(steamCMDPath)
|
||||
|
||||
// Check if SteamCMD exists
|
||||
if _, err := os.Stat(steamCMDPath); !os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create directory if it doesn't exist
|
||||
if err := os.MkdirAll(steamCMDDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create SteamCMD directory: %v", err)
|
||||
}
|
||||
|
||||
// Download and install SteamCMD
|
||||
logging.Info("Downloading SteamCMD...")
|
||||
if err := s.executor.Execute("-Command",
|
||||
"Invoke-WebRequest -Uri 'https://steamcdn-a.akamaihd.net/client/installer/steamcmd.zip' -OutFile 'steamcmd.zip'"); err != nil {
|
||||
steamCMDZip := filepath.Join(steamCMDDir, "steamcmd.zip")
|
||||
if err := s.downloadVerifier.VerifyAndDownload(
|
||||
"https://steamcdn-a.akamaihd.net/client/installer/steamcmd.zip",
|
||||
steamCMDZip,
|
||||
""); err != nil {
|
||||
return fmt.Errorf("failed to download SteamCMD: %v", err)
|
||||
}
|
||||
|
||||
// Extract SteamCMD
|
||||
logging.Info("Extracting SteamCMD...")
|
||||
if err := s.executor.Execute("-Command",
|
||||
fmt.Sprintf("Expand-Archive -Path 'steamcmd.zip' -DestinationPath '%s'", steamCMDDir)); err != nil {
|
||||
return fmt.Errorf("failed to extract SteamCMD: %v", err)
|
||||
}
|
||||
|
||||
// Clean up zip file
|
||||
os.Remove("steamcmd.zip")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SteamService) InstallServer(ctx context.Context, installPath string) error {
|
||||
func (s *SteamService) InstallServerWithWebSocket(ctx context.Context, installPath string, serverID *uuid.UUID, wsService *WebSocketService) error {
|
||||
if err := s.ensureSteamCMD(ctx); err != nil {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Error ensuring SteamCMD: %v", err), true)
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert to absolute path and ensure proper Windows path format
|
||||
if err := s.pathValidator.ValidateInstallPath(installPath); err != nil {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Invalid installation path: %v", err), true)
|
||||
return fmt.Errorf("invalid installation path: %v", err)
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(installPath)
|
||||
if err != nil {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to get absolute path: %v", err), true)
|
||||
return fmt.Errorf("failed to get absolute path: %v", err)
|
||||
}
|
||||
absPath = filepath.Clean(absPath)
|
||||
|
||||
// Ensure install path exists
|
||||
if err := os.MkdirAll(absPath, 0755); err != nil {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to create install directory: %v", err), true)
|
||||
return fmt.Errorf("failed to create install directory: %v", err)
|
||||
}
|
||||
|
||||
// Get Steam credentials
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Installation directory prepared: %s", absPath), false)
|
||||
|
||||
creds, err := s.GetCredentials(ctx)
|
||||
if err != nil {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Failed to get Steam credentials: %v", err), true)
|
||||
return fmt.Errorf("failed to get Steam credentials: %v", err)
|
||||
}
|
||||
|
||||
// Get SteamCMD path from environment variable
|
||||
steamCMDPath := env.GetSteamCMDPath()
|
||||
|
||||
// Build SteamCMD command
|
||||
args := []string{
|
||||
"-nologo",
|
||||
"-noprofile",
|
||||
steamCMDPath,
|
||||
steamCMDArgs := []string{
|
||||
"+force_install_dir", absPath,
|
||||
"+login",
|
||||
}
|
||||
|
||||
if creds != nil && creds.Username != "" {
|
||||
args = append(args, creds.Username)
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Using Steam credentials for user: %s", creds.Username), false)
|
||||
steamCMDArgs = append(steamCMDArgs, creds.Username)
|
||||
if creds.Password != "" {
|
||||
args = append(args, creds.Password)
|
||||
steamCMDArgs = append(steamCMDArgs, creds.Password)
|
||||
}
|
||||
} else {
|
||||
args = append(args, "anonymous")
|
||||
wsService.BroadcastSteamOutput(*serverID, "Using anonymous Steam login", false)
|
||||
steamCMDArgs = append(steamCMDArgs, "anonymous")
|
||||
}
|
||||
|
||||
args = append(args,
|
||||
steamCMDArgs = append(steamCMDArgs,
|
||||
"+app_update", ACCServerAppID,
|
||||
"validate",
|
||||
"+quit",
|
||||
)
|
||||
|
||||
// Run SteamCMD
|
||||
logging.Info("Installing ACC server to %s...", absPath)
|
||||
if err := s.executor.Execute(args...); err != nil {
|
||||
args := steamCMDArgs
|
||||
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Starting SteamCMD: %s %s", steamCMDPath, strings.Join(args, " ")), false)
|
||||
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
callbackConfig := &command.CallbackConfig{
|
||||
OnOutput: func(serverID uuid.UUID, output string, isError bool) {
|
||||
wsService.BroadcastSteamOutput(serverID, output, isError)
|
||||
},
|
||||
OnCommand: func(serverID uuid.UUID, command string, args []string, completed bool, success bool, error string) {
|
||||
if completed {
|
||||
if success {
|
||||
wsService.BroadcastSteamOutput(serverID, "Command completed successfully", false)
|
||||
} else {
|
||||
wsService.BroadcastSteamOutput(serverID, fmt.Sprintf("Command failed: %s", error), true)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
callbackInteractiveExecutor := command.NewCallbackInteractiveCommandExecutor(s.executor, s.tfaManager, callbackConfig, *serverID)
|
||||
callbackInteractiveExecutor.ExePath = steamCMDPath
|
||||
|
||||
if err := callbackInteractiveExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("SteamCMD execution failed: %v", err), true)
|
||||
if timeoutCtx.Err() == context.DeadlineExceeded {
|
||||
return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required")
|
||||
}
|
||||
return fmt.Errorf("failed to run SteamCMD: %v", err)
|
||||
}
|
||||
|
||||
// Add a delay to allow Steam to properly cleanup
|
||||
logging.Info("Waiting for Steam operations to complete...")
|
||||
if err := s.executor.Execute("-Command", "Start-Sleep -Seconds 5"); err != nil {
|
||||
logging.Warn("Failed to wait after Steam operations: %v", err)
|
||||
}
|
||||
wsService.BroadcastSteamOutput(*serverID, "SteamCMD execution completed successfully, proceeding with verification...", false)
|
||||
|
||||
wsService.BroadcastSteamOutput(*serverID, "Waiting for Steam operations to complete...", false)
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Verify installation
|
||||
exePath := filepath.Join(absPath, "server", "accServer.exe")
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Checking for ACC server executable at: %s", exePath), false)
|
||||
|
||||
if _, err := os.Stat(exePath); os.IsNotExist(err) {
|
||||
return fmt.Errorf("server installation failed: accServer.exe not found in %s", absPath)
|
||||
wsService.BroadcastSteamOutput(*serverID, "accServer.exe not found, checking directory contents...", false)
|
||||
|
||||
if entries, dirErr := os.ReadDir(absPath); dirErr == nil {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Contents of %s:", absPath), false)
|
||||
for _, entry := range entries {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false)
|
||||
}
|
||||
}
|
||||
|
||||
serverDir := filepath.Join(absPath, "server")
|
||||
if entries, dirErr := os.ReadDir(serverDir); dirErr == nil {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Contents of %s:", serverDir), false)
|
||||
for _, entry := range entries {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false)
|
||||
}
|
||||
} else {
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Server directory %s does not exist or cannot be read: %v", serverDir, dirErr), true)
|
||||
}
|
||||
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Server installation failed: accServer.exe not found in %s", exePath), true)
|
||||
return fmt.Errorf("server installation failed: accServer.exe not found in %s", exePath)
|
||||
}
|
||||
|
||||
logging.Info("Server installation completed successfully")
|
||||
wsService.BroadcastSteamOutput(*serverID, fmt.Sprintf("Server installation completed successfully - accServer.exe found at %s", exePath), false)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SteamService) UpdateServer(ctx context.Context, installPath string) error {
|
||||
return s.InstallServer(ctx, installPath) // Same process as install
|
||||
func (s *SteamService) InstallServerWithCallbacks(ctx context.Context, installPath string, serverID *uuid.UUID, outputCallback command.OutputCallback) error {
|
||||
if err := s.ensureSteamCMD(ctx); err != nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("Error ensuring SteamCMD: %v", err), true)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.pathValidator.ValidateInstallPath(installPath); err != nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("Invalid installation path: %v", err), true)
|
||||
return fmt.Errorf("invalid installation path: %v", err)
|
||||
}
|
||||
|
||||
absPath, err := filepath.Abs(installPath)
|
||||
if err != nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("Failed to get absolute path: %v", err), true)
|
||||
return fmt.Errorf("failed to get absolute path: %v", err)
|
||||
}
|
||||
absPath = filepath.Clean(absPath)
|
||||
|
||||
if err := os.MkdirAll(absPath, 0755); err != nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("Failed to create install directory: %v", err), true)
|
||||
return fmt.Errorf("failed to create install directory: %v", err)
|
||||
}
|
||||
|
||||
outputCallback(*serverID, fmt.Sprintf("Installation directory prepared: %s", absPath), false)
|
||||
|
||||
creds, err := s.GetCredentials(ctx)
|
||||
if err != nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("Failed to get Steam credentials: %v", err), true)
|
||||
return fmt.Errorf("failed to get Steam credentials: %v", err)
|
||||
}
|
||||
|
||||
steamCMDPath := env.GetSteamCMDPath()
|
||||
|
||||
steamCMDArgs := []string{
|
||||
"+force_install_dir", absPath,
|
||||
"+login",
|
||||
}
|
||||
|
||||
if creds != nil && creds.Username != "" {
|
||||
outputCallback(*serverID, fmt.Sprintf("Using Steam credentials for user: %s", creds.Username), false)
|
||||
steamCMDArgs = append(steamCMDArgs, creds.Username)
|
||||
if creds.Password != "" {
|
||||
steamCMDArgs = append(steamCMDArgs, creds.Password)
|
||||
}
|
||||
} else {
|
||||
outputCallback(*serverID, "Using anonymous Steam login", false)
|
||||
steamCMDArgs = append(steamCMDArgs, "anonymous")
|
||||
}
|
||||
|
||||
steamCMDArgs = append(steamCMDArgs,
|
||||
"+app_update", ACCServerAppID,
|
||||
"validate",
|
||||
"+quit",
|
||||
)
|
||||
|
||||
args := steamCMDArgs
|
||||
|
||||
outputCallback(*serverID, fmt.Sprintf("Starting SteamCMD: %s %s", steamCMDPath, strings.Join(args, " ")), false)
|
||||
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
callbacks := &command.CallbackConfig{
|
||||
OnOutput: outputCallback,
|
||||
}
|
||||
|
||||
callbackExecutor := command.NewCallbackInteractiveCommandExecutor(s.executor, s.tfaManager, callbacks, *serverID)
|
||||
callbackExecutor.ExePath = steamCMDPath
|
||||
|
||||
if err := callbackExecutor.ExecuteInteractive(timeoutCtx, serverID, args...); err != nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("SteamCMD execution failed: %v", err), true)
|
||||
if timeoutCtx.Err() == context.DeadlineExceeded {
|
||||
return fmt.Errorf("SteamCMD operation timed out after 15 minutes - this usually means Steam Guard confirmation is required")
|
||||
}
|
||||
return fmt.Errorf("failed to run SteamCMD: %v", err)
|
||||
}
|
||||
|
||||
outputCallback(*serverID, "SteamCMD execution completed successfully, proceeding with verification...", false)
|
||||
|
||||
outputCallback(*serverID, "Waiting for Steam operations to complete...", false)
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
exePath := filepath.Join(absPath, "server", "accServer.exe")
|
||||
outputCallback(*serverID, fmt.Sprintf("Checking for ACC server executable at: %s", exePath), false)
|
||||
|
||||
if _, err := os.Stat(exePath); os.IsNotExist(err) {
|
||||
outputCallback(*serverID, "accServer.exe not found, checking directory contents...", false)
|
||||
|
||||
if entries, dirErr := os.ReadDir(absPath); dirErr == nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("Contents of %s:", absPath), false)
|
||||
for _, entry := range entries {
|
||||
outputCallback(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false)
|
||||
}
|
||||
}
|
||||
|
||||
serverDir := filepath.Join(absPath, "server")
|
||||
if entries, dirErr := os.ReadDir(serverDir); dirErr == nil {
|
||||
outputCallback(*serverID, fmt.Sprintf("Contents of %s:", serverDir), false)
|
||||
for _, entry := range entries {
|
||||
outputCallback(*serverID, fmt.Sprintf(" - %s (dir: %v)", entry.Name(), entry.IsDir()), false)
|
||||
}
|
||||
} else {
|
||||
outputCallback(*serverID, fmt.Sprintf("Server directory %s does not exist or cannot be read: %v", serverDir, dirErr), true)
|
||||
}
|
||||
|
||||
outputCallback(*serverID, fmt.Sprintf("Server installation failed: accServer.exe not found in %s", exePath), true)
|
||||
return fmt.Errorf("server installation failed: accServer.exe not found in %s", exePath)
|
||||
}
|
||||
|
||||
outputCallback(*serverID, fmt.Sprintf("Server installation completed successfully - accServer.exe found at %s", exePath), false)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SteamService) UninstallServer(installPath string) error {
|
||||
|
||||
185
local/service/websocket.go
Normal file
185
local/service/websocket.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/model"
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type WebSocketConnection struct {
|
||||
conn *websocket.Conn
|
||||
serverID *uuid.UUID
|
||||
userID *uuid.UUID
|
||||
}
|
||||
|
||||
type WebSocketService struct {
|
||||
connections sync.Map
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewWebSocketService() *WebSocketService {
|
||||
return &WebSocketService{}
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, userID *uuid.UUID) {
|
||||
wsConn := &WebSocketConnection{
|
||||
conn: conn,
|
||||
userID: userID,
|
||||
}
|
||||
ws.connections.Store(connID, wsConn)
|
||||
logging.Info("WebSocket connection added: %s for user: %v", connID, userID)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) RemoveConnection(connID string) {
|
||||
if conn, exists := ws.connections.LoadAndDelete(connID); exists {
|
||||
if wsConn, ok := conn.(*WebSocketConnection); ok {
|
||||
wsConn.conn.Close()
|
||||
}
|
||||
}
|
||||
logging.Info("WebSocket connection removed: %s", connID)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) {
|
||||
if conn, exists := ws.connections.Load(connID); exists {
|
||||
if wsConn, ok := conn.(*WebSocketConnection); ok {
|
||||
wsConn.serverID = &serverID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerCreationStep, status model.StepStatus, message string, errorMsg string) {
|
||||
stepMsg := model.StepMessage{
|
||||
Step: step,
|
||||
Status: status,
|
||||
Message: message,
|
||||
Error: errorMsg,
|
||||
}
|
||||
|
||||
wsMsg := model.WebSocketMessage{
|
||||
Type: model.MessageTypeStep,
|
||||
ServerID: &serverID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Data: stepMsg,
|
||||
}
|
||||
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output string, isError bool) {
|
||||
steamMsg := model.SteamOutputMessage{
|
||||
Output: output,
|
||||
IsError: isError,
|
||||
}
|
||||
|
||||
wsMsg := model.WebSocketMessage{
|
||||
Type: model.MessageTypeSteamOutput,
|
||||
ServerID: &serverID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Data: steamMsg,
|
||||
}
|
||||
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, details string) {
|
||||
errorMsg := model.ErrorMessage{
|
||||
Error: error,
|
||||
Details: details,
|
||||
}
|
||||
|
||||
wsMsg := model.WebSocketMessage{
|
||||
Type: model.MessageTypeError,
|
||||
ServerID: &serverID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Data: errorMsg,
|
||||
}
|
||||
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool, message string) {
|
||||
completeMsg := model.CompleteMessage{
|
||||
ServerID: serverID,
|
||||
Success: success,
|
||||
Message: message,
|
||||
}
|
||||
|
||||
wsMsg := model.WebSocketMessage{
|
||||
Type: model.MessageTypeComplete,
|
||||
ServerID: &serverID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Data: completeMsg,
|
||||
}
|
||||
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model.WebSocketMessage) {
|
||||
data, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
logging.Error("Failed to marshal WebSocket message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
sentToAssociatedConnections := false
|
||||
|
||||
ws.connections.Range(func(key, value interface{}) bool {
|
||||
if wsConn, ok := value.(*WebSocketConnection); ok {
|
||||
if wsConn.serverID != nil && *wsConn.serverID == serverID {
|
||||
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
|
||||
ws.RemoveConnection(key.(string))
|
||||
} else {
|
||||
sentToAssociatedConnections = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if !sentToAssociatedConnections && (message.Type == model.MessageTypeStep || message.Type == model.MessageTypeError || message.Type == model.MessageTypeComplete) {
|
||||
ws.connections.Range(func(key, value interface{}) bool {
|
||||
if wsConn, ok := value.(*WebSocketConnection); ok {
|
||||
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
|
||||
ws.RemoveConnection(key.(string))
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebSocketMessage) {
|
||||
data, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
logging.Error("Failed to marshal WebSocket message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ws.connections.Range(func(key, value interface{}) bool {
|
||||
if wsConn, ok := value.(*WebSocketConnection); ok {
|
||||
if wsConn.userID != nil && *wsConn.userID == userID {
|
||||
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
|
||||
ws.RemoveConnection(key.(string))
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) GetActiveConnections() int {
|
||||
count := 0
|
||||
ws.connections.Range(func(key, value interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
@@ -23,34 +23,25 @@ func NewWindowsService() *WindowsService {
|
||||
}
|
||||
}
|
||||
|
||||
// executeNSSM runs an NSSM command through PowerShell with elevation
|
||||
func (s *WindowsService) ExecuteNSSM(ctx context.Context, args ...string) (string, error) {
|
||||
// Get NSSM path from environment variable
|
||||
nssmPath := env.GetNSSMPath()
|
||||
|
||||
// Prepend NSSM path to arguments
|
||||
nssmArgs := append([]string{"-NoProfile", "-NonInteractive", "-Command", "& " + nssmPath}, args...)
|
||||
|
||||
output, err := s.executor.ExecuteWithOutput(nssmArgs...)
|
||||
if err != nil {
|
||||
// Log the full command and error for debugging
|
||||
logging.Error("NSSM command failed: powershell %s", strings.Join(nssmArgs, " "))
|
||||
logging.Error("NSSM error output: %s", output)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Clean up output by removing null bytes and trimming whitespace
|
||||
cleaned := strings.TrimSpace(strings.ReplaceAll(output, "\x00", ""))
|
||||
// Remove \r\n from status strings
|
||||
cleaned = strings.TrimSuffix(cleaned, "\r\n")
|
||||
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
// Service Installation/Configuration Methods
|
||||
|
||||
func (s *WindowsService) CreateService(ctx context.Context, serviceName, execPath, workingDir string, args []string) error {
|
||||
// Ensure paths are absolute and properly formatted for Windows
|
||||
absExecPath, err := filepath.Abs(execPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get absolute path for executable: %v", err)
|
||||
@@ -63,30 +54,24 @@ func (s *WindowsService) CreateService(ctx context.Context, serviceName, execPat
|
||||
}
|
||||
absWorkingDir = filepath.Clean(absWorkingDir)
|
||||
|
||||
// Log the paths being used
|
||||
logging.Info("Creating service '%s' with:", serviceName)
|
||||
logging.Info(" Executable: %s", absExecPath)
|
||||
logging.Info(" Working Directory: %s", absWorkingDir)
|
||||
|
||||
// First remove any existing service with the same name
|
||||
s.ExecuteNSSM(ctx, "remove", serviceName, "confirm")
|
||||
|
||||
// Install service
|
||||
if _, err := s.ExecuteNSSM(ctx, "install", serviceName, absExecPath); err != nil {
|
||||
return fmt.Errorf("failed to install service: %v", err)
|
||||
}
|
||||
|
||||
// Set arguments if provided
|
||||
if len(args) > 0 {
|
||||
cmdArgs := append([]string{"set", serviceName, "AppParameters"}, args...)
|
||||
if _, err := s.ExecuteNSSM(ctx, cmdArgs...); err != nil {
|
||||
// Try to clean up on failure
|
||||
s.ExecuteNSSM(ctx, "remove", serviceName, "confirm")
|
||||
return fmt.Errorf("failed to set arguments: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify service was created
|
||||
if _, err := s.ExecuteNSSM(ctx, "get", serviceName, "Application"); err != nil {
|
||||
return fmt.Errorf("service creation verification failed: %v", err)
|
||||
}
|
||||
@@ -105,17 +90,13 @@ func (s *WindowsService) DeleteService(ctx context.Context, serviceName string)
|
||||
}
|
||||
|
||||
func (s *WindowsService) UpdateService(ctx context.Context, serviceName, execPath, workingDir string, args []string) error {
|
||||
// First remove the existing service
|
||||
if err := s.DeleteService(ctx, serviceName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then create it again with new parameters
|
||||
return s.CreateService(ctx, serviceName, execPath, workingDir, args)
|
||||
}
|
||||
|
||||
// Service Control Methods
|
||||
|
||||
func (s *WindowsService) Status(ctx context.Context, serviceName string) (string, error) {
|
||||
return s.ExecuteNSSM(ctx, "status", serviceName)
|
||||
}
|
||||
@@ -129,11 +110,9 @@ func (s *WindowsService) Stop(ctx context.Context, serviceName string) (string,
|
||||
}
|
||||
|
||||
func (s *WindowsService) Restart(ctx context.Context, serviceName string) (string, error) {
|
||||
// First stop the service
|
||||
if _, err := s.Stop(ctx, serviceName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Then start it again
|
||||
return s.Start(ctx, serviceName)
|
||||
}
|
||||
|
||||
64
local/utl/audit/audit.go
Normal file
64
local/utl/audit/audit.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AuditAction string
|
||||
|
||||
const (
|
||||
ActionLogin AuditAction = "LOGIN"
|
||||
ActionLogout AuditAction = "LOGOUT"
|
||||
ActionServerCreate AuditAction = "SERVER_CREATE"
|
||||
ActionServerUpdate AuditAction = "SERVER_UPDATE"
|
||||
ActionServerDelete AuditAction = "SERVER_DELETE"
|
||||
ActionServerStart AuditAction = "SERVER_START"
|
||||
ActionServerStop AuditAction = "SERVER_STOP"
|
||||
ActionUserCreate AuditAction = "USER_CREATE"
|
||||
ActionUserUpdate AuditAction = "USER_UPDATE"
|
||||
ActionUserDelete AuditAction = "USER_DELETE"
|
||||
ActionConfigUpdate AuditAction = "CONFIG_UPDATE"
|
||||
ActionSteamAuth AuditAction = "STEAM_AUTH"
|
||||
ActionPermissionGrant AuditAction = "PERMISSION_GRANT"
|
||||
ActionPermissionRevoke AuditAction = "PERMISSION_REVOKE"
|
||||
)
|
||||
|
||||
type AuditEntry struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
UserID string `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Action AuditAction `json:"action"`
|
||||
Resource string `json:"resource"`
|
||||
Details string `json:"details"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
func LogAction(ctx context.Context, userID, username string, action AuditAction, resource, details, ipAddress, userAgent string, success bool) {
|
||||
logging.InfoWithContext("AUDIT", "User %s (%s) performed %s on %s from %s - Success: %t - Details: %s",
|
||||
username, userID, action, resource, ipAddress, success, details)
|
||||
}
|
||||
|
||||
func LogAuthAction(ctx context.Context, username, ipAddress, userAgent string, success bool, details string) {
|
||||
action := ActionLogin
|
||||
if !success {
|
||||
details = "Failed: " + details
|
||||
}
|
||||
|
||||
LogAction(ctx, "", username, action, "authentication", details, ipAddress, userAgent, success)
|
||||
}
|
||||
|
||||
func LogServerAction(ctx context.Context, userID, username string, action AuditAction, serverID, ipAddress, userAgent string, success bool, details string) {
|
||||
LogAction(ctx, userID, username, action, "server:"+serverID, details, ipAddress, userAgent, success)
|
||||
}
|
||||
|
||||
func LogUserManagementAction(ctx context.Context, adminUserID, adminUsername string, action AuditAction, targetUserID, ipAddress, userAgent string, success bool, details string) {
|
||||
LogAction(ctx, adminUserID, adminUsername, action, "user:"+targetUserID, details, ipAddress, userAgent, success)
|
||||
}
|
||||
|
||||
func LogConfigAction(ctx context.Context, userID, username string, configType, ipAddress, userAgent string, success bool, details string) {
|
||||
LogAction(ctx, userID, username, ActionConfigUpdate, "config:"+configType, details, ipAddress, userAgent, success)
|
||||
}
|
||||
12
local/utl/cache/cache.go
vendored
12
local/utl/cache/cache.go
vendored
@@ -9,26 +9,22 @@ import (
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
// CacheItem represents an item in the cache
|
||||
type CacheItem struct {
|
||||
Value interface{}
|
||||
Expiration int64
|
||||
}
|
||||
|
||||
// InMemoryCache is a thread-safe in-memory cache
|
||||
type InMemoryCache struct {
|
||||
items map[string]CacheItem
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewInMemoryCache creates and returns a new InMemoryCache instance
|
||||
func NewInMemoryCache() *InMemoryCache {
|
||||
return &InMemoryCache{
|
||||
items: make(map[string]CacheItem),
|
||||
}
|
||||
}
|
||||
|
||||
// Set adds an item to the cache with an expiration duration (in seconds)
|
||||
func (c *InMemoryCache) Set(key string, value interface{}, duration time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
@@ -44,7 +40,6 @@ func (c *InMemoryCache) Set(key string, value interface{}, duration time.Duratio
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves an item from the cache
|
||||
func (c *InMemoryCache) Get(key string) (interface{}, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
@@ -55,24 +50,18 @@ func (c *InMemoryCache) Get(key string) (interface{}, bool) {
|
||||
}
|
||||
|
||||
if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration {
|
||||
// Item has expired, but don't delete here to avoid lock upgrade.
|
||||
// It will be overwritten on the next Set.
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return item.Value, true
|
||||
}
|
||||
|
||||
// Delete removes an item from the cache
|
||||
func (c *InMemoryCache) Delete(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.items, key)
|
||||
}
|
||||
|
||||
// GetOrSet retrieves an item from the cache. If the item is not found, it
|
||||
// calls the provided function to get the value, sets it in the cache, and
|
||||
// returns it.
|
||||
func GetOrSet[T any](c *InMemoryCache, key string, duration time.Duration, fetcher func() (T, error)) (T, error) {
|
||||
if cached, found := c.Get(key); found {
|
||||
if value, ok := cached.(T); ok {
|
||||
@@ -90,7 +79,6 @@ func GetOrSet[T any](c *InMemoryCache, key string, duration time.Duration, fetch
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Start initializes the cache and provides it to the DI container.
|
||||
func Start(di *dig.Container) {
|
||||
cache := NewInMemoryCache()
|
||||
err := di.Provide(func() *InMemoryCache {
|
||||
|
||||
245
local/utl/command/callback_executor.go
Normal file
245
local/utl/command/callback_executor.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/model"
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type CallbackInteractiveCommandExecutor struct {
|
||||
*InteractiveCommandExecutor
|
||||
callbacks *CallbackConfig
|
||||
serverID uuid.UUID
|
||||
}
|
||||
|
||||
func NewCallbackInteractiveCommandExecutor(baseExecutor *CommandExecutor, tfaManager *model.Steam2FAManager, callbacks *CallbackConfig, serverID uuid.UUID) *CallbackInteractiveCommandExecutor {
|
||||
if callbacks == nil {
|
||||
callbacks = DefaultCallbackConfig()
|
||||
}
|
||||
|
||||
return &CallbackInteractiveCommandExecutor{
|
||||
InteractiveCommandExecutor: &InteractiveCommandExecutor{
|
||||
CommandExecutor: baseExecutor,
|
||||
tfaManager: tfaManager,
|
||||
},
|
||||
callbacks: callbacks,
|
||||
serverID: serverID,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *CallbackInteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error {
|
||||
cmd := exec.CommandContext(ctx, e.ExePath, args...)
|
||||
|
||||
if e.WorkDir != "" {
|
||||
cmd.Dir = e.WorkDir
|
||||
}
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdin pipe: %v", err)
|
||||
}
|
||||
defer stdin.Close()
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdout pipe: %v", err)
|
||||
}
|
||||
defer stdout.Close()
|
||||
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stderr pipe: %v", err)
|
||||
}
|
||||
defer stderr.Close()
|
||||
|
||||
logging.Info("Executing interactive command with callbacks: %s %s", e.ExePath, strings.Join(args, " "))
|
||||
|
||||
e.callbacks.OnCommand(e.serverID, e.ExePath, args, false, false, "")
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Starting command: %s %s", e.ExePath, strings.Join(args, " ")), false)
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to start command: %v", err), true)
|
||||
return fmt.Errorf("failed to start command: %v", err)
|
||||
}
|
||||
|
||||
outputDone := make(chan error, 1)
|
||||
cmdDone := make(chan error, 1)
|
||||
|
||||
go e.monitorOutputWithCallbacks(ctx, stdout, stderr, serverID, outputDone)
|
||||
|
||||
go func() {
|
||||
cmdDone <- cmd.Wait()
|
||||
}()
|
||||
|
||||
var cmdErr, outputErr error
|
||||
completedCount := 0
|
||||
|
||||
for completedCount < 2 {
|
||||
select {
|
||||
case cmdErr = <-cmdDone:
|
||||
completedCount++
|
||||
logging.Info("Command execution completed")
|
||||
e.callbacks.OnOutput(e.serverID, "Command execution completed", false)
|
||||
case outputErr = <-outputDone:
|
||||
completedCount++
|
||||
logging.Info("Output monitoring completed")
|
||||
case <-ctx.Done():
|
||||
e.callbacks.OnOutput(e.serverID, "Command execution cancelled", true)
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
if outputErr != nil {
|
||||
logging.Warn("Output monitoring error: %v", outputErr)
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Output monitoring error: %v", outputErr), true)
|
||||
}
|
||||
|
||||
success := cmdErr == nil
|
||||
errorMsg := ""
|
||||
if cmdErr != nil {
|
||||
errorMsg = cmdErr.Error()
|
||||
}
|
||||
e.callbacks.OnCommand(e.serverID, e.ExePath, args, true, success, errorMsg)
|
||||
|
||||
return cmdErr
|
||||
}
|
||||
|
||||
func (e *CallbackInteractiveCommandExecutor) monitorOutputWithCallbacks(ctx context.Context, stdout, stderr io.Reader, serverID *uuid.UUID, done chan error) {
|
||||
defer func() {
|
||||
select {
|
||||
case done <- nil:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
stdoutScanner := bufio.NewScanner(stdout)
|
||||
stderrScanner := bufio.NewScanner(stderr)
|
||||
|
||||
outputChan := make(chan outputLine, 100)
|
||||
readersDone := make(chan struct{}, 2)
|
||||
|
||||
steamConsoleStarted := false
|
||||
tfaRequestCreated := false
|
||||
|
||||
go func() {
|
||||
defer func() { readersDone <- struct{}{} }()
|
||||
for stdoutScanner.Scan() {
|
||||
line := stdoutScanner.Text()
|
||||
if e.LogOutput {
|
||||
logging.Info("STDOUT: %s", line)
|
||||
}
|
||||
e.callbacks.OnOutput(e.serverID, line, false)
|
||||
|
||||
select {
|
||||
case outputChan <- outputLine{text: line, isError: false}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := stdoutScanner.Err(); err != nil {
|
||||
logging.Warn("Stdout scanner error: %v", err)
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Stdout scanner error: %v", err), true)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer func() { readersDone <- struct{}{} }()
|
||||
for stderrScanner.Scan() {
|
||||
line := stderrScanner.Text()
|
||||
if e.LogOutput {
|
||||
logging.Info("STDERR: %s", line)
|
||||
}
|
||||
e.callbacks.OnOutput(e.serverID, line, true)
|
||||
|
||||
select {
|
||||
case outputChan <- outputLine{text: line, isError: true}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := stderrScanner.Err(); err != nil {
|
||||
logging.Warn("Stderr scanner error: %v", err)
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Stderr scanner error: %v", err), true)
|
||||
}
|
||||
}()
|
||||
|
||||
readersFinished := 0
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
done <- ctx.Err()
|
||||
return
|
||||
case <-readersDone:
|
||||
readersFinished++
|
||||
if readersFinished == 2 {
|
||||
close(outputChan)
|
||||
for lineData := range outputChan {
|
||||
if e.is2FAPrompt(lineData.text) {
|
||||
if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil {
|
||||
logging.Error("Failed to handle 2FA prompt: %v", err)
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true)
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
case lineData, ok := <-outputChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
lowerLine := strings.ToLower(lineData.text)
|
||||
if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") {
|
||||
steamConsoleStarted = true
|
||||
logging.Info("Steam Console Client startup detected - will monitor for 2FA hang")
|
||||
e.callbacks.OnOutput(e.serverID, "Steam Console Client startup detected", false)
|
||||
}
|
||||
|
||||
if e.is2FAPrompt(lineData.text) {
|
||||
if !tfaRequestCreated {
|
||||
e.callbacks.OnOutput(e.serverID, "2FA prompt detected - waiting for user confirmation", false)
|
||||
if err := e.handle2FAPrompt(ctx, lineData.text, serverID); err != nil {
|
||||
logging.Error("Failed to handle 2FA prompt: %v", err)
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle 2FA prompt: %v", err), true)
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
tfaRequestCreated = true
|
||||
}
|
||||
}
|
||||
|
||||
if tfaRequestCreated && e.isSteamContinuing(lineData.text) {
|
||||
logging.Info("Steam CMD appears to have continued after 2FA confirmation")
|
||||
e.callbacks.OnOutput(e.serverID, "Steam CMD continued after 2FA confirmation", false)
|
||||
e.autoCompletePendingRequests(serverID)
|
||||
}
|
||||
case <-time.After(15 * time.Second):
|
||||
if steamConsoleStarted && !tfaRequestCreated {
|
||||
logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA")
|
||||
e.callbacks.OnOutput(e.serverID, "Waiting for Steam Guard 2FA confirmation...", false)
|
||||
if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil {
|
||||
logging.Error("Failed to handle Steam Guard 2FA prompt: %v", err)
|
||||
e.callbacks.OnOutput(e.serverID, fmt.Sprintf("Failed to handle Steam Guard 2FA prompt: %v", err), true)
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
tfaRequestCreated = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type outputLine struct {
|
||||
text string
|
||||
isError bool
|
||||
}
|
||||
19
local/utl/command/callbacks.go
Normal file
19
local/utl/command/callbacks.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package command
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type OutputCallback func(serverID uuid.UUID, output string, isError bool)
|
||||
|
||||
type CommandCallback func(serverID uuid.UUID, command string, args []string, completed bool, success bool, error string)
|
||||
|
||||
type CallbackConfig struct {
|
||||
OnOutput OutputCallback
|
||||
OnCommand CommandCallback
|
||||
}
|
||||
|
||||
func DefaultCallbackConfig() *CallbackConfig {
|
||||
return &CallbackConfig{
|
||||
OnOutput: func(uuid.UUID, string, bool) {},
|
||||
OnCommand: func(uuid.UUID, string, []string, bool, bool, string) {},
|
||||
}
|
||||
}
|
||||
@@ -8,17 +8,12 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CommandExecutor provides a base structure for executing commands
|
||||
type CommandExecutor struct {
|
||||
// Base executable path
|
||||
ExePath string
|
||||
// Working directory for commands
|
||||
WorkDir string
|
||||
// Whether to capture and log output
|
||||
ExePath string
|
||||
WorkDir string
|
||||
LogOutput bool
|
||||
}
|
||||
|
||||
// CommandBuilder helps build command arguments
|
||||
type CommandBuilder struct {
|
||||
args []string
|
||||
}
|
||||
@@ -48,10 +43,9 @@ func (b *CommandBuilder) Build() []string {
|
||||
return b.args
|
||||
}
|
||||
|
||||
// Execute runs a command with the given arguments
|
||||
func (e *CommandExecutor) Execute(args ...string) error {
|
||||
cmd := exec.Command(e.ExePath, args...)
|
||||
|
||||
|
||||
if e.WorkDir != "" {
|
||||
cmd.Dir = e.WorkDir
|
||||
}
|
||||
@@ -65,15 +59,13 @@ func (e *CommandExecutor) Execute(args ...string) error {
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// ExecuteWithBuilder runs a command using a CommandBuilder
|
||||
func (e *CommandExecutor) ExecuteWithBuilder(builder *CommandBuilder) error {
|
||||
return e.Execute(builder.Build()...)
|
||||
}
|
||||
|
||||
// ExecuteWithOutput runs a command and returns its output
|
||||
func (e *CommandExecutor) ExecuteWithOutput(args ...string) (string, error) {
|
||||
cmd := exec.Command(e.ExePath, args...)
|
||||
|
||||
|
||||
if e.WorkDir != "" {
|
||||
cmd.Dir = e.WorkDir
|
||||
}
|
||||
@@ -83,10 +75,9 @@ func (e *CommandExecutor) ExecuteWithOutput(args ...string) (string, error) {
|
||||
return string(output), err
|
||||
}
|
||||
|
||||
// ExecuteWithEnv runs a command with custom environment variables
|
||||
func (e *CommandExecutor) ExecuteWithEnv(env []string, args ...string) error {
|
||||
cmd := exec.Command(e.ExePath, args...)
|
||||
|
||||
|
||||
if e.WorkDir != "" {
|
||||
cmd.Dir = e.WorkDir
|
||||
}
|
||||
@@ -100,4 +91,4 @@ func (e *CommandExecutor) ExecuteWithEnv(env []string, args ...string) error {
|
||||
|
||||
logging.Info("Executing command: %s %s", e.ExePath, strings.Join(args, " "))
|
||||
return cmd.Run()
|
||||
}
|
||||
}
|
||||
|
||||
324
local/utl/command/interactive_executor.go
Normal file
324
local/utl/command/interactive_executor.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/model"
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type InteractiveCommandExecutor struct {
|
||||
*CommandExecutor
|
||||
tfaManager *model.Steam2FAManager
|
||||
}
|
||||
|
||||
func NewInteractiveCommandExecutor(baseExecutor *CommandExecutor, tfaManager *model.Steam2FAManager) *InteractiveCommandExecutor {
|
||||
return &InteractiveCommandExecutor{
|
||||
CommandExecutor: baseExecutor,
|
||||
tfaManager: tfaManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *InteractiveCommandExecutor) ExecuteInteractive(ctx context.Context, serverID *uuid.UUID, args ...string) error {
|
||||
cmd := exec.CommandContext(ctx, e.ExePath, args...)
|
||||
|
||||
if e.WorkDir != "" {
|
||||
cmd.Dir = e.WorkDir
|
||||
}
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdin pipe: %v", err)
|
||||
}
|
||||
defer stdin.Close()
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdout pipe: %v", err)
|
||||
}
|
||||
defer stdout.Close()
|
||||
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stderr pipe: %v", err)
|
||||
}
|
||||
defer stderr.Close()
|
||||
|
||||
logging.Info("Executing interactive command: %s %s", e.ExePath, strings.Join(args, " "))
|
||||
|
||||
debugMode := os.Getenv("STEAMCMD_DEBUG") == "true"
|
||||
if debugMode {
|
||||
logging.Info("STEAMCMD_DEBUG mode enabled - will log all output and create proactive 2FA requests")
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start command: %v", err)
|
||||
}
|
||||
|
||||
outputDone := make(chan error, 1)
|
||||
cmdDone := make(chan error, 1)
|
||||
|
||||
go e.monitorOutput(ctx, stdout, stderr, serverID, outputDone)
|
||||
|
||||
go func() {
|
||||
cmdDone <- cmd.Wait()
|
||||
}()
|
||||
|
||||
var cmdErr, outputErr error
|
||||
completedCount := 0
|
||||
|
||||
for completedCount < 2 {
|
||||
select {
|
||||
case cmdErr = <-cmdDone:
|
||||
completedCount++
|
||||
logging.Info("Command execution completed")
|
||||
case outputErr = <-outputDone:
|
||||
completedCount++
|
||||
logging.Info("Output monitoring completed")
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
if outputErr != nil {
|
||||
logging.Warn("Output monitoring error: %v", outputErr)
|
||||
}
|
||||
|
||||
return cmdErr
|
||||
}
|
||||
|
||||
func (e *InteractiveCommandExecutor) monitorOutput(ctx context.Context, stdout, stderr io.Reader, serverID *uuid.UUID, done chan error) {
|
||||
defer func() {
|
||||
select {
|
||||
case done <- nil:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
stdoutScanner := bufio.NewScanner(stdout)
|
||||
stderrScanner := bufio.NewScanner(stderr)
|
||||
|
||||
outputChan := make(chan string, 100)
|
||||
readersDone := make(chan struct{}, 2)
|
||||
|
||||
steamConsoleStarted := false
|
||||
tfaRequestCreated := false
|
||||
|
||||
go func() {
|
||||
defer func() { readersDone <- struct{}{} }()
|
||||
for stdoutScanner.Scan() {
|
||||
line := stdoutScanner.Text()
|
||||
if e.LogOutput {
|
||||
logging.Info("STDOUT: %s", line)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(line), "steam") {
|
||||
logging.Info("STEAM_DEBUG: %s", line)
|
||||
}
|
||||
select {
|
||||
case outputChan <- line:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := stdoutScanner.Err(); err != nil {
|
||||
logging.Warn("Stdout scanner error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer func() { readersDone <- struct{}{} }()
|
||||
for stderrScanner.Scan() {
|
||||
line := stderrScanner.Text()
|
||||
if e.LogOutput {
|
||||
logging.Info("STDERR: %s", line)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(line), "steam") {
|
||||
logging.Info("STEAM_DEBUG_ERR: %s", line)
|
||||
}
|
||||
select {
|
||||
case outputChan <- line:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := stderrScanner.Err(); err != nil {
|
||||
logging.Warn("Stderr scanner error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
readersFinished := 0
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
done <- ctx.Err()
|
||||
return
|
||||
case <-readersDone:
|
||||
readersFinished++
|
||||
if readersFinished == 2 {
|
||||
close(outputChan)
|
||||
for line := range outputChan {
|
||||
if e.is2FAPrompt(line) {
|
||||
if err := e.handle2FAPrompt(ctx, line, serverID); err != nil {
|
||||
logging.Error("Failed to handle 2FA prompt: %v", err)
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
case line, ok := <-outputChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
lowerLine := strings.ToLower(line)
|
||||
if strings.Contains(lowerLine, "steam console client") && strings.Contains(lowerLine, "valve corporation") {
|
||||
steamConsoleStarted = true
|
||||
logging.Info("Steam Console Client startup detected - will monitor for 2FA hang")
|
||||
}
|
||||
|
||||
if e.is2FAPrompt(line) {
|
||||
if !tfaRequestCreated {
|
||||
if err := e.handle2FAPrompt(ctx, line, serverID); err != nil {
|
||||
logging.Error("Failed to handle 2FA prompt: %v", err)
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
tfaRequestCreated = true
|
||||
}
|
||||
}
|
||||
|
||||
if tfaRequestCreated && e.isSteamContinuing(line) {
|
||||
logging.Info("Steam CMD appears to have continued after 2FA confirmation - auto-completing 2FA request")
|
||||
e.autoCompletePendingRequests(serverID)
|
||||
}
|
||||
case <-time.After(15 * time.Second):
|
||||
if steamConsoleStarted && !tfaRequestCreated {
|
||||
logging.Info("Steam Console started but no output for 15 seconds - likely waiting for Steam Guard 2FA")
|
||||
if err := e.handle2FAPrompt(ctx, "Steam CMD appears to be waiting for Steam Guard confirmation after startup", serverID); err != nil {
|
||||
logging.Error("Failed to handle Steam Guard 2FA prompt: %v", err)
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
tfaRequestCreated = true
|
||||
} else if !steamConsoleStarted {
|
||||
logging.Info("No output for 15 seconds (Steam Console not yet started)")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *InteractiveCommandExecutor) is2FAPrompt(line string) bool {
|
||||
twoFAKeywords := []string{
|
||||
"please enter your steam guard code",
|
||||
"steam guard",
|
||||
"two-factor",
|
||||
"authentication code",
|
||||
"please check your steam mobile app",
|
||||
"confirm in application",
|
||||
"enter the current code from your steam mobile app",
|
||||
"steam guard mobile authenticator",
|
||||
"waiting for user info",
|
||||
"login failure",
|
||||
"two factor code required",
|
||||
"enter steam guard code",
|
||||
"mobile authenticator code",
|
||||
"authenticator app",
|
||||
"guard code",
|
||||
"mobile app",
|
||||
"confirmation required",
|
||||
}
|
||||
|
||||
lowerLine := strings.ToLower(line)
|
||||
for _, keyword := range twoFAKeywords {
|
||||
if strings.Contains(lowerLine, keyword) {
|
||||
logging.Info("2FA keyword match found: '%s' in line: '%s'", keyword, line)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
waitingPatterns := []string{
|
||||
"waiting for",
|
||||
"please enter",
|
||||
"enter code",
|
||||
"code:",
|
||||
"authenticator:",
|
||||
}
|
||||
|
||||
for _, pattern := range waitingPatterns {
|
||||
if strings.Contains(lowerLine, pattern) {
|
||||
logging.Info("Potential 2FA waiting pattern found: '%s' in line: '%s'", pattern, line)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *InteractiveCommandExecutor) isSteamContinuing(line string) bool {
|
||||
lowerLine := strings.ToLower(line)
|
||||
continuingPatterns := []string{
|
||||
"loading steam api",
|
||||
"logging in user",
|
||||
"waiting for client config",
|
||||
"waiting for user info",
|
||||
"update state",
|
||||
"success! app",
|
||||
"fully installed",
|
||||
}
|
||||
|
||||
for _, pattern := range continuingPatterns {
|
||||
if strings.Contains(lowerLine, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *InteractiveCommandExecutor) autoCompletePendingRequests(serverID *uuid.UUID) {
|
||||
if e.tfaManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pendingRequests := e.tfaManager.GetPendingRequests()
|
||||
for _, req := range pendingRequests {
|
||||
if req.ServerID != nil && serverID != nil && *req.ServerID == *serverID {
|
||||
logging.Info("Auto-completing 2FA request %s for server %s", req.ID, serverID.String())
|
||||
if err := e.tfaManager.CompleteRequest(req.ID); err != nil {
|
||||
logging.Warn("Failed to auto-complete 2FA request %s: %v", req.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *InteractiveCommandExecutor) handle2FAPrompt(_ context.Context, promptLine string, serverID *uuid.UUID) error {
|
||||
logging.Info("2FA prompt detected: %s", promptLine)
|
||||
|
||||
request := e.tfaManager.CreateRequest(promptLine, serverID)
|
||||
logging.Info("Created 2FA request with ID: %s", request.ID)
|
||||
|
||||
timeout := 5 * time.Minute
|
||||
success, err := e.tfaManager.WaitForCompletion(request.ID, timeout)
|
||||
|
||||
if err != nil {
|
||||
logging.Error("2FA completion failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !success {
|
||||
logging.Error("2FA was not completed successfully")
|
||||
return fmt.Errorf("2FA authentication failed")
|
||||
}
|
||||
|
||||
logging.Info("2FA completed successfully")
|
||||
return nil
|
||||
}
|
||||
@@ -25,6 +25,7 @@ type RouteGroups struct {
|
||||
StateHistory fiber.Router
|
||||
Membership fiber.Router
|
||||
System fiber.Router
|
||||
WebSocket fiber.Router
|
||||
}
|
||||
|
||||
func CheckError(err error) {
|
||||
@@ -78,12 +79,6 @@ func IndentJson(body []byte) ([]byte, error) {
|
||||
return unmarshaledBody.Bytes(), nil
|
||||
}
|
||||
|
||||
// ParseQueryFilter parses query parameters into a filter struct using reflection.
|
||||
// It supports various field types and uses struct tags to determine parsing behavior.
|
||||
// Supported tags:
|
||||
// - `query:"field_name"` - specifies the query parameter name
|
||||
// - `param:"param_name"` - specifies the path parameter name
|
||||
// - `time_format:"format"` - specifies the time format for parsing dates (default: RFC3339)
|
||||
func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
|
||||
val := reflect.ValueOf(filter)
|
||||
if val.Kind() != reflect.Ptr || val.IsNil() {
|
||||
@@ -93,14 +88,12 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
|
||||
elem := val.Elem()
|
||||
typ := elem.Type()
|
||||
|
||||
// Process all fields including embedded structs
|
||||
var processFields func(reflect.Value, reflect.Type) error
|
||||
processFields = func(val reflect.Value, typ reflect.Type) error {
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := val.Field(i)
|
||||
fieldType := typ.Field(i)
|
||||
|
||||
// Handle embedded structs recursively
|
||||
if fieldType.Anonymous {
|
||||
if err := processFields(field, fieldType.Type); err != nil {
|
||||
return err
|
||||
@@ -108,12 +101,10 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if field cannot be set
|
||||
if !field.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for param tag first (path parameters)
|
||||
if paramName := fieldType.Tag.Get("param"); paramName != "" {
|
||||
if err := parsePathParam(c, field, paramName); err != nil {
|
||||
return fmt.Errorf("error parsing path parameter %s: %v", paramName, err)
|
||||
@@ -121,15 +112,14 @@ func ParseQueryFilter(c *fiber.Ctx, filter interface{}) error {
|
||||
continue
|
||||
}
|
||||
|
||||
// Then check for query tag
|
||||
queryName := fieldType.Tag.Get("query")
|
||||
if queryName == "" {
|
||||
queryName = ToSnakeCase(fieldType.Name) // Default to snake_case of field name
|
||||
queryName = ToSnakeCase(fieldType.Name)
|
||||
}
|
||||
|
||||
queryVal := c.Query(queryName)
|
||||
if queryVal == "" {
|
||||
continue // Skip empty values
|
||||
continue
|
||||
}
|
||||
|
||||
if err := parseValue(field, queryVal, fieldType.Tag); err != nil {
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
Version = "0.0.1"
|
||||
Version = "0.10.7"
|
||||
Prefix = "v1"
|
||||
Secret string
|
||||
SecretCode string
|
||||
@@ -18,7 +18,6 @@ var (
|
||||
|
||||
func Init() {
|
||||
godotenv.Load()
|
||||
// Fail fast if critical environment variables are missing
|
||||
Secret = getEnvRequired("APP_SECRET")
|
||||
SecretCode = getEnvRequired("APP_SECRET_CODE")
|
||||
EncryptionKey = getEnvRequired("ENCRYPTION_KEY")
|
||||
@@ -29,7 +28,6 @@ func Init() {
|
||||
}
|
||||
}
|
||||
|
||||
// getEnv retrieves an environment variable or returns a fallback value.
|
||||
func getEnv(key, fallback string) string {
|
||||
if value, exists := os.LookupEnv(key); exists {
|
||||
return value
|
||||
@@ -38,12 +36,10 @@ func getEnv(key, fallback string) string {
|
||||
return fallback
|
||||
}
|
||||
|
||||
// getEnvRequired retrieves an environment variable and fails if it's not set.
|
||||
// This should be used for critical configuration that must not have defaults.
|
||||
func getEnvRequired(key string) string {
|
||||
if value, exists := os.LookupEnv(key); exists && value != "" {
|
||||
return value
|
||||
}
|
||||
log.Fatalf("Required environment variable %s is not set or is empty", key)
|
||||
return "" // This line will never be reached due to log.Fatalf
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -33,7 +33,6 @@ func Start(di *dig.Container) {
|
||||
func Migrate(db *gorm.DB) {
|
||||
logging.Info("Migrating database")
|
||||
|
||||
// Run GORM AutoMigrate for all models
|
||||
err := db.AutoMigrate(
|
||||
&model.ServiceControlModel{},
|
||||
&model.Config{},
|
||||
@@ -52,7 +51,6 @@ func Migrate(db *gorm.DB) {
|
||||
|
||||
if err != nil {
|
||||
logging.Error("GORM AutoMigrate failed: %v", err)
|
||||
// Don't panic, just log the error as custom migrations may have handled this
|
||||
}
|
||||
|
||||
db.FirstOrCreate(&model.ServiceControlModel{ServiceControl: "Works"})
|
||||
@@ -63,10 +61,8 @@ func Migrate(db *gorm.DB) {
|
||||
func runMigrations(db *gorm.DB) {
|
||||
logging.Info("Running custom database migrations...")
|
||||
|
||||
// Migration 001: Password security upgrade
|
||||
if err := migrations.RunPasswordSecurityMigration(db); err != nil {
|
||||
logging.Error("Failed to run password security migration: %v", err)
|
||||
// Continue - this migration might not be needed for all setups
|
||||
}
|
||||
|
||||
logging.Info("Custom database migrations completed")
|
||||
@@ -132,7 +128,6 @@ func seedCarModels(db *gorm.DB) error {
|
||||
carModels := []model.CarModel{
|
||||
{Value: 0, CarModel: "Porsche 991 GT3 R"},
|
||||
{Value: 1, CarModel: "Mercedes-AMG GT3"},
|
||||
// ... Add all car models from your list
|
||||
}
|
||||
|
||||
for _, cm := range carModels {
|
||||
|
||||
7
local/utl/env/env.go
vendored
7
local/utl/env/env.go
vendored
@@ -6,12 +6,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// Default paths for when environment variables are not set
|
||||
DefaultSteamCMDPath = "c:\\steamcmd\\steamcmd.exe"
|
||||
DefaultNSSMPath = ".\\nssm.exe"
|
||||
)
|
||||
|
||||
// GetSteamCMDPath returns the SteamCMD executable path from environment variable or default
|
||||
func GetSteamCMDPath() string {
|
||||
if path := os.Getenv("STEAMCMD_PATH"); path != "" {
|
||||
return path
|
||||
@@ -19,13 +17,11 @@ func GetSteamCMDPath() string {
|
||||
return DefaultSteamCMDPath
|
||||
}
|
||||
|
||||
// GetSteamCMDDirPath returns the directory containing SteamCMD executable
|
||||
func GetSteamCMDDirPath() string {
|
||||
steamCMDPath := GetSteamCMDPath()
|
||||
return filepath.Dir(steamCMDPath)
|
||||
}
|
||||
|
||||
// GetNSSMPath returns the NSSM executable path from environment variable or default
|
||||
func GetNSSMPath() string {
|
||||
if path := os.Getenv("NSSM_PATH"); path != "" {
|
||||
return path
|
||||
@@ -33,17 +29,14 @@ func GetNSSMPath() string {
|
||||
return DefaultNSSMPath
|
||||
}
|
||||
|
||||
// ValidatePaths checks if the configured paths exist (optional validation)
|
||||
func ValidatePaths() map[string]error {
|
||||
errors := make(map[string]error)
|
||||
|
||||
// Check SteamCMD path
|
||||
steamCMDPath := GetSteamCMDPath()
|
||||
if _, err := os.Stat(steamCMDPath); os.IsNotExist(err) {
|
||||
errors["STEAMCMD_PATH"] = err
|
||||
}
|
||||
|
||||
// Check NSSM path
|
||||
nssmPath := GetNSSMPath()
|
||||
if _, err := os.Stat(nssmPath); os.IsNotExist(err) {
|
||||
errors["NSSM_PATH"] = err
|
||||
|
||||
@@ -9,45 +9,37 @@ import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// ControllerErrorHandler provides centralized error handling for controllers
|
||||
type ControllerErrorHandler struct {
|
||||
errorLogger *logging.ErrorLogger
|
||||
}
|
||||
|
||||
// NewControllerErrorHandler creates a new controller error handler instance
|
||||
func NewControllerErrorHandler() *ControllerErrorHandler {
|
||||
return &ControllerErrorHandler{
|
||||
errorLogger: logging.GetErrorLogger(),
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorResponse represents a standardized error response
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Code int `json:"code,omitempty"`
|
||||
Details map[string]string `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// HandleError handles controller errors with logging and standardized responses
|
||||
func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCode int, context ...string) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get caller information for logging
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
file = strings.TrimPrefix(file, "acc-server-manager/")
|
||||
|
||||
// Build context string
|
||||
contextStr := ""
|
||||
if len(context) > 0 {
|
||||
contextStr = fmt.Sprintf("[%s] ", strings.Join(context, "|"))
|
||||
}
|
||||
|
||||
// Clean error message (remove null bytes)
|
||||
cleanErrorMsg := strings.ReplaceAll(err.Error(), "\x00", "")
|
||||
|
||||
// Log the error with context
|
||||
ceh.errorLogger.LogWithContext(
|
||||
fmt.Sprintf("CONTROLLER_ERROR [%s:%d]", file, line),
|
||||
"%s%s",
|
||||
@@ -55,25 +47,39 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo
|
||||
cleanErrorMsg,
|
||||
)
|
||||
|
||||
// Create standardized error response
|
||||
errorResponse := ErrorResponse{
|
||||
Error: cleanErrorMsg,
|
||||
Code: statusCode,
|
||||
}
|
||||
|
||||
// Add request details if available
|
||||
if c != nil {
|
||||
if errorResponse.Details == nil {
|
||||
errorResponse.Details = make(map[string]string)
|
||||
}
|
||||
errorResponse.Details["method"] = c.Method()
|
||||
errorResponse.Details["path"] = c.Path()
|
||||
errorResponse.Details["ip"] = c.IP()
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
errorResponse.Details["method"] = c.Method()
|
||||
errorResponse.Details["path"] = c.Path()
|
||||
|
||||
if ip := c.IP(); ip != "" {
|
||||
errorResponse.Details["ip"] = ip
|
||||
} else {
|
||||
errorResponse.Details["ip"] = "unknown"
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if c == nil {
|
||||
return fmt.Errorf("cannot return HTTP response: context is nil")
|
||||
}
|
||||
|
||||
// Return appropriate response based on status code
|
||||
if statusCode >= 500 {
|
||||
// For server errors, don't expose internal details
|
||||
return c.Status(statusCode).JSON(ErrorResponse{
|
||||
Error: "Internal server error",
|
||||
Code: statusCode,
|
||||
@@ -83,52 +89,42 @@ func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCo
|
||||
return c.Status(statusCode).JSON(errorResponse)
|
||||
}
|
||||
|
||||
// HandleValidationError handles validation errors specifically
|
||||
func (ceh *ControllerErrorHandler) HandleValidationError(c *fiber.Ctx, err error, field string) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusBadRequest, "VALIDATION", field)
|
||||
}
|
||||
|
||||
// HandleDatabaseError handles database-related errors
|
||||
func (ceh *ControllerErrorHandler) HandleDatabaseError(c *fiber.Ctx, err error) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusInternalServerError, "DATABASE")
|
||||
}
|
||||
|
||||
// HandleAuthError handles authentication/authorization errors
|
||||
func (ceh *ControllerErrorHandler) HandleAuthError(c *fiber.Ctx, err error) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusUnauthorized, "AUTH")
|
||||
}
|
||||
|
||||
// HandleNotFoundError handles resource not found errors
|
||||
func (ceh *ControllerErrorHandler) HandleNotFoundError(c *fiber.Ctx, resource string) error {
|
||||
err := fmt.Errorf("%s not found", resource)
|
||||
return ceh.HandleError(c, err, fiber.StatusNotFound, "NOT_FOUND")
|
||||
}
|
||||
|
||||
// HandleBusinessLogicError handles business logic errors
|
||||
func (ceh *ControllerErrorHandler) HandleBusinessLogicError(c *fiber.Ctx, err error) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusBadRequest, "BUSINESS_LOGIC")
|
||||
}
|
||||
|
||||
// HandleServiceError handles service layer errors
|
||||
func (ceh *ControllerErrorHandler) HandleServiceError(c *fiber.Ctx, err error) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusInternalServerError, "SERVICE")
|
||||
}
|
||||
|
||||
// HandleParsingError handles request parsing errors
|
||||
func (ceh *ControllerErrorHandler) HandleParsingError(c *fiber.Ctx, err error) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusBadRequest, "PARSING")
|
||||
}
|
||||
|
||||
// HandleUUIDError handles UUID parsing errors
|
||||
func (ceh *ControllerErrorHandler) HandleUUIDError(c *fiber.Ctx, field string) error {
|
||||
err := fmt.Errorf("invalid %s format", field)
|
||||
return ceh.HandleError(c, err, fiber.StatusBadRequest, "UUID_VALIDATION", field)
|
||||
}
|
||||
|
||||
// Global controller error handler instance
|
||||
var globalErrorHandler *ControllerErrorHandler
|
||||
|
||||
// GetControllerErrorHandler returns the global controller error handler instance
|
||||
func GetControllerErrorHandler() *ControllerErrorHandler {
|
||||
if globalErrorHandler == nil {
|
||||
globalErrorHandler = NewControllerErrorHandler()
|
||||
@@ -136,49 +132,38 @@ func GetControllerErrorHandler() *ControllerErrorHandler {
|
||||
return globalErrorHandler
|
||||
}
|
||||
|
||||
// Convenience functions using the global error handler
|
||||
|
||||
// HandleError handles controller errors using the global error handler
|
||||
func HandleError(c *fiber.Ctx, err error, statusCode int, context ...string) error {
|
||||
return GetControllerErrorHandler().HandleError(c, err, statusCode, context...)
|
||||
}
|
||||
|
||||
// HandleValidationError handles validation errors using the global error handler
|
||||
func HandleValidationError(c *fiber.Ctx, err error, field string) error {
|
||||
return GetControllerErrorHandler().HandleValidationError(c, err, field)
|
||||
}
|
||||
|
||||
// HandleDatabaseError handles database errors using the global error handler
|
||||
func HandleDatabaseError(c *fiber.Ctx, err error) error {
|
||||
return GetControllerErrorHandler().HandleDatabaseError(c, err)
|
||||
}
|
||||
|
||||
// HandleAuthError handles auth errors using the global error handler
|
||||
func HandleAuthError(c *fiber.Ctx, err error) error {
|
||||
return GetControllerErrorHandler().HandleAuthError(c, err)
|
||||
}
|
||||
|
||||
// HandleNotFoundError handles not found errors using the global error handler
|
||||
func HandleNotFoundError(c *fiber.Ctx, resource string) error {
|
||||
return GetControllerErrorHandler().HandleNotFoundError(c, resource)
|
||||
}
|
||||
|
||||
// HandleBusinessLogicError handles business logic errors using the global error handler
|
||||
func HandleBusinessLogicError(c *fiber.Ctx, err error) error {
|
||||
return GetControllerErrorHandler().HandleBusinessLogicError(c, err)
|
||||
}
|
||||
|
||||
// HandleServiceError handles service errors using the global error handler
|
||||
func HandleServiceError(c *fiber.Ctx, err error) error {
|
||||
return GetControllerErrorHandler().HandleServiceError(c, err)
|
||||
}
|
||||
|
||||
// HandleParsingError handles parsing errors using the global error handler
|
||||
func HandleParsingError(c *fiber.Ctx, err error) error {
|
||||
return GetControllerErrorHandler().HandleParsingError(c, err)
|
||||
}
|
||||
|
||||
// HandleUUIDError handles UUID errors using the global error handler
|
||||
func HandleUUIDError(c *fiber.Ctx, field string) error {
|
||||
return GetControllerErrorHandler().HandleUUIDError(c, field)
|
||||
}
|
||||
|
||||
67
local/utl/errors/safe_error.go
Normal file
67
local/utl/errors/safe_error.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
type SafeError struct {
|
||||
Message string
|
||||
Code int
|
||||
Fatal bool
|
||||
}
|
||||
|
||||
func (e *SafeError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func NewSafeError(message string, code int) *SafeError {
|
||||
return &SafeError{
|
||||
Message: message,
|
||||
Code: code,
|
||||
Fatal: false,
|
||||
}
|
||||
}
|
||||
|
||||
func NewFatalError(message string, code int) *SafeError {
|
||||
return &SafeError{
|
||||
Message: message,
|
||||
Code: code,
|
||||
Fatal: true,
|
||||
}
|
||||
}
|
||||
|
||||
func HandleError(err error, context string) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if safeErr, ok := err.(*SafeError); ok {
|
||||
if safeErr.Fatal {
|
||||
logging.Error("Fatal error in %s: %s", context, safeErr.Message)
|
||||
if os.Getenv("ENVIRONMENT") == "production" {
|
||||
logging.Error("Application shutting down due to fatal error")
|
||||
os.Exit(safeErr.Code)
|
||||
} else {
|
||||
logging.Warn("Fatal error occurred but not exiting in non-production environment")
|
||||
}
|
||||
} else {
|
||||
logging.Error("Error in %s: %s", context, safeErr.Message)
|
||||
}
|
||||
} else {
|
||||
logging.Error("Unexpected error in %s: %v", context, err)
|
||||
}
|
||||
}
|
||||
|
||||
func SafeFatal(message string, args ...interface{}) {
|
||||
formattedMessage := fmt.Sprintf(message, args...)
|
||||
err := NewFatalError(formattedMessage, 1)
|
||||
HandleError(err, "application")
|
||||
}
|
||||
|
||||
func SafeLog(message string, args ...interface{}) {
|
||||
formattedMessage := fmt.Sprintf(message, args...)
|
||||
err := NewSafeError(formattedMessage, 0)
|
||||
HandleError(err, "application")
|
||||
}
|
||||
91
local/utl/graceful/shutdown.go
Normal file
91
local/utl/graceful/shutdown.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package graceful
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ShutdownManager struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
handlers []func() error
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
var globalManager *ShutdownManager
|
||||
var once sync.Once
|
||||
|
||||
func GetManager() *ShutdownManager {
|
||||
once.Do(func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
globalManager = &ShutdownManager{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
handlers: make([]func() error, 0),
|
||||
}
|
||||
|
||||
go globalManager.watchSignals()
|
||||
})
|
||||
return globalManager
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) watchSignals() {
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
<-sigChan
|
||||
sm.Shutdown(30 * time.Second)
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) AddHandler(handler func() error) {
|
||||
sm.mutex.Lock()
|
||||
defer sm.mutex.Unlock()
|
||||
sm.handlers = append(sm.handlers, handler)
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) Context() context.Context {
|
||||
return sm.ctx
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) AddGoroutine() {
|
||||
sm.wg.Add(1)
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) GoroutineDone() {
|
||||
sm.wg.Done()
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) RunGoroutine(fn func(ctx context.Context)) {
|
||||
sm.wg.Add(1)
|
||||
go func() {
|
||||
defer sm.wg.Done()
|
||||
fn(sm.ctx)
|
||||
}()
|
||||
}
|
||||
|
||||
func (sm *ShutdownManager) Shutdown(timeout time.Duration) {
|
||||
sm.cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
sm.wg.Wait()
|
||||
|
||||
sm.mutex.Lock()
|
||||
for _, handler := range sm.handlers {
|
||||
handler()
|
||||
}
|
||||
sm.mutex.Unlock()
|
||||
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(timeout):
|
||||
}
|
||||
}
|
||||
@@ -2,88 +2,100 @@ package jwt
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/model"
|
||||
"acc-server-manager/local/utl/errors"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"log"
|
||||
"os"
|
||||
goerrors "errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
// SecretKey holds the JWT signing key loaded from environment
|
||||
var SecretKey []byte
|
||||
|
||||
// Claims represents the JWT claims.
|
||||
type Claims struct {
|
||||
UserID string `json:"user_id"`
|
||||
UserID string `json:"user_id"`
|
||||
IsOpenToken bool `json:"is_open_token"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// init initializes the JWT secret key from environment variable
|
||||
func Init() {
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
log.Fatal("JWT_SECRET environment variable is required and cannot be empty")
|
||||
}
|
||||
type JWTHandler struct {
|
||||
SecretKey []byte
|
||||
IsOpenToken bool
|
||||
}
|
||||
|
||||
// Decode base64 secret if it looks like base64, otherwise use as-is
|
||||
if decoded, err := base64.StdEncoding.DecodeString(jwtSecret); err == nil && len(decoded) >= 32 {
|
||||
SecretKey = decoded
|
||||
} else {
|
||||
SecretKey = []byte(jwtSecret)
|
||||
}
|
||||
type OpenJWTHandler struct {
|
||||
*JWTHandler
|
||||
}
|
||||
|
||||
// Ensure minimum key length for security
|
||||
if len(SecretKey) < 32 {
|
||||
log.Fatal("JWT_SECRET must be at least 32 bytes long for security")
|
||||
func NewOpenJWTHandler(jwtSecret string) *OpenJWTHandler {
|
||||
jwtHandler := NewJWTHandler(jwtSecret)
|
||||
jwtHandler.IsOpenToken = true
|
||||
return &OpenJWTHandler{
|
||||
JWTHandler: jwtHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSecretKey generates a cryptographically secure random key for JWT signing
|
||||
// This is a utility function for generating new secrets, not used in normal operation
|
||||
func GenerateSecretKey() string {
|
||||
func NewJWTHandler(jwtSecret string) *JWTHandler {
|
||||
if jwtSecret == "" {
|
||||
errors.SafeFatal("JWT_SECRET environment variable is required and cannot be empty")
|
||||
}
|
||||
|
||||
var secretKey []byte
|
||||
|
||||
if decoded, err := base64.StdEncoding.DecodeString(jwtSecret); err == nil && len(decoded) >= 32 {
|
||||
secretKey = decoded
|
||||
} else {
|
||||
secretKey = []byte(jwtSecret)
|
||||
}
|
||||
|
||||
if len(secretKey) < 32 {
|
||||
errors.SafeFatal("JWT_SECRET must be at least 32 bytes long for security")
|
||||
}
|
||||
return &JWTHandler{
|
||||
SecretKey: secretKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (jh *JWTHandler) GenerateSecretKey() string {
|
||||
key := make([]byte, 64) // 512 bits
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
log.Fatal("Failed to generate random key: ", err)
|
||||
errors.SafeFatal("Failed to generate random key: %v", err)
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(key)
|
||||
}
|
||||
|
||||
// GenerateToken generates a new JWT for a given user.
|
||||
func GenerateToken(user *model.User) (string, error) {
|
||||
func (jh *JWTHandler) GenerateToken(userId string) (string, error) {
|
||||
expirationTime := time.Now().Add(24 * time.Hour)
|
||||
claims := &Claims{
|
||||
UserID: user.ID.String(),
|
||||
UserID: userId,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(expirationTime),
|
||||
},
|
||||
IsOpenToken: jh.IsOpenToken,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(SecretKey)
|
||||
return token.SignedString(jh.SecretKey)
|
||||
}
|
||||
|
||||
func GenerateTokenWithExpiry(user *model.User, expiry time.Time) (string, error) {
|
||||
func (jh *JWTHandler) GenerateTokenWithExpiry(user *model.User, expiry time.Time) (string, error) {
|
||||
expirationTime := expiry
|
||||
claims := &Claims{
|
||||
UserID: user.ID.String(),
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(expirationTime),
|
||||
},
|
||||
IsOpenToken: jh.IsOpenToken,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(SecretKey)
|
||||
return token.SignedString(jh.SecretKey)
|
||||
}
|
||||
|
||||
// ValidateToken validates a JWT and returns the claims if the token is valid.
|
||||
func ValidateToken(tokenString string) (*Claims, error) {
|
||||
func (jh *JWTHandler) ValidateToken(tokenString string) (*Claims, error) {
|
||||
claims := &Claims{}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
return SecretKey, nil
|
||||
return jh.SecretKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -91,7 +103,7 @@ func ValidateToken(tokenString string) (*Claims, error) {
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
return nil, goerrors.New("invalid token")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
|
||||
@@ -15,7 +15,6 @@ var (
|
||||
timeFormat = "2006-01-02 15:04:05.000"
|
||||
)
|
||||
|
||||
// BaseLogger provides the core logging functionality
|
||||
type BaseLogger struct {
|
||||
file *os.File
|
||||
logger *log.Logger
|
||||
@@ -23,7 +22,6 @@ type BaseLogger struct {
|
||||
initialized bool
|
||||
}
|
||||
|
||||
// LogLevel represents different logging levels
|
||||
type LogLevel string
|
||||
|
||||
const (
|
||||
@@ -34,28 +32,23 @@ const (
|
||||
LogLevelPanic LogLevel = "PANIC"
|
||||
)
|
||||
|
||||
// Initialize creates a new base logger instance
|
||||
func InitializeBase(tp string) (*BaseLogger, error) {
|
||||
return newBaseLogger(tp)
|
||||
}
|
||||
|
||||
func newBaseLogger(tp string) (*BaseLogger, error) {
|
||||
// Ensure logs directory exists
|
||||
if err := os.MkdirAll("logs", 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create logs directory: %v", err)
|
||||
}
|
||||
|
||||
// Open log file with date in name
|
||||
logPath := filepath.Join("logs", fmt.Sprintf("acc-server-%s-%s.log", time.Now().Format("2006-01-02"), tp))
|
||||
file, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0666)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open log file: %v", err)
|
||||
}
|
||||
|
||||
// Create multi-writer for both file and console
|
||||
multiWriter := io.MultiWriter(file, os.Stdout)
|
||||
|
||||
// Create base logger
|
||||
logger := &BaseLogger{
|
||||
file: file,
|
||||
logger: log.New(multiWriter, "", 0),
|
||||
@@ -65,13 +58,11 @@ func newBaseLogger(tp string) (*BaseLogger, error) {
|
||||
return logger, nil
|
||||
}
|
||||
|
||||
// GetBaseLogger creates and returns a new base logger instance
|
||||
func GetBaseLogger(tp string) *BaseLogger {
|
||||
baseLogger, _ := InitializeBase(tp)
|
||||
return baseLogger
|
||||
}
|
||||
|
||||
// Close closes the log file
|
||||
func (bl *BaseLogger) Close() error {
|
||||
bl.mu.Lock()
|
||||
defer bl.mu.Unlock()
|
||||
@@ -82,7 +73,6 @@ func (bl *BaseLogger) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Log writes a log entry with the specified level
|
||||
func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) {
|
||||
if bl == nil || !bl.initialized {
|
||||
return
|
||||
@@ -91,14 +81,11 @@ func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) {
|
||||
bl.mu.RLock()
|
||||
defer bl.mu.RUnlock()
|
||||
|
||||
// Get caller info (skip 2 frames: this function and the calling Log function)
|
||||
_, file, line, _ := runtime.Caller(2)
|
||||
file = filepath.Base(file)
|
||||
|
||||
// Format message
|
||||
msg := fmt.Sprintf(format, v...)
|
||||
|
||||
// Format final log line
|
||||
logLine := fmt.Sprintf("[%s] [%s] [%s:%d] %s",
|
||||
time.Now().Format(timeFormat),
|
||||
string(level),
|
||||
@@ -110,7 +97,6 @@ func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) {
|
||||
bl.logger.Println(logLine)
|
||||
}
|
||||
|
||||
// LogWithCaller writes a log entry with custom caller depth
|
||||
func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format string, v ...interface{}) {
|
||||
if bl == nil || !bl.initialized {
|
||||
return
|
||||
@@ -119,14 +105,11 @@ func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format stri
|
||||
bl.mu.RLock()
|
||||
defer bl.mu.RUnlock()
|
||||
|
||||
// Get caller info with custom depth
|
||||
_, file, line, _ := runtime.Caller(callerDepth)
|
||||
file = filepath.Base(file)
|
||||
|
||||
// Format message
|
||||
msg := fmt.Sprintf(format, v...)
|
||||
|
||||
// Format final log line
|
||||
logLine := fmt.Sprintf("[%s] [%s] [%s:%d] %s",
|
||||
time.Now().Format(timeFormat),
|
||||
string(level),
|
||||
@@ -138,7 +121,6 @@ func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format stri
|
||||
bl.logger.Println(logLine)
|
||||
}
|
||||
|
||||
// IsInitialized returns whether the base logger is initialized
|
||||
func (bl *BaseLogger) IsInitialized() bool {
|
||||
if bl == nil {
|
||||
return false
|
||||
@@ -148,19 +130,16 @@ func (bl *BaseLogger) IsInitialized() bool {
|
||||
return bl.initialized
|
||||
}
|
||||
|
||||
// RecoverAndLog recovers from panics and logs them
|
||||
func RecoverAndLog() {
|
||||
baseLogger := GetBaseLogger("panic")
|
||||
if baseLogger != nil && baseLogger.IsInitialized() {
|
||||
if r := recover(); r != nil {
|
||||
// Get stack trace
|
||||
buf := make([]byte, 4096)
|
||||
n := runtime.Stack(buf, false)
|
||||
stackTrace := string(buf[:n])
|
||||
|
||||
baseLogger.LogWithCaller(LogLevelPanic, 2, "Recovered from panic: %v\nStack Trace:\n%s", r, stackTrace)
|
||||
|
||||
// Re-panic to maintain original behavior if needed
|
||||
panic(r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,12 +6,10 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// DebugLogger handles debug-level logging
|
||||
type DebugLogger struct {
|
||||
base *BaseLogger
|
||||
}
|
||||
|
||||
// NewDebugLogger creates a new debug logger instance
|
||||
func NewDebugLogger() *DebugLogger {
|
||||
base, _ := InitializeBase("debug")
|
||||
return &DebugLogger{
|
||||
@@ -19,14 +17,12 @@ func NewDebugLogger() *DebugLogger {
|
||||
}
|
||||
}
|
||||
|
||||
// Log writes a debug-level log entry
|
||||
func (dl *DebugLogger) Log(format string, v ...interface{}) {
|
||||
if dl.base != nil {
|
||||
dl.base.Log(LogLevelDebug, format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithContext writes a debug-level log entry with additional context
|
||||
func (dl *DebugLogger) LogWithContext(context string, format string, v ...interface{}) {
|
||||
if dl.base != nil {
|
||||
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
|
||||
@@ -34,7 +30,6 @@ func (dl *DebugLogger) LogWithContext(context string, format string, v ...interf
|
||||
}
|
||||
}
|
||||
|
||||
// LogFunction logs function entry and exit for debugging
|
||||
func (dl *DebugLogger) LogFunction(functionName string, args ...interface{}) {
|
||||
if dl.base != nil {
|
||||
if len(args) > 0 {
|
||||
@@ -45,21 +40,18 @@ func (dl *DebugLogger) LogFunction(functionName string, args ...interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// LogVariable logs variable values for debugging
|
||||
func (dl *DebugLogger) LogVariable(varName string, value interface{}) {
|
||||
if dl.base != nil {
|
||||
dl.base.Log(LogLevelDebug, "VARIABLE [%s]: %+v", varName, value)
|
||||
}
|
||||
}
|
||||
|
||||
// LogState logs application state information
|
||||
func (dl *DebugLogger) LogState(component string, state interface{}) {
|
||||
if dl.base != nil {
|
||||
dl.base.Log(LogLevelDebug, "STATE [%s]: %+v", component, state)
|
||||
}
|
||||
}
|
||||
|
||||
// LogSQL logs SQL queries for debugging
|
||||
func (dl *DebugLogger) LogSQL(query string, args ...interface{}) {
|
||||
if dl.base != nil {
|
||||
if len(args) > 0 {
|
||||
@@ -70,7 +62,6 @@ func (dl *DebugLogger) LogSQL(query string, args ...interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// LogMemory logs memory usage information
|
||||
func (dl *DebugLogger) LogMemory() {
|
||||
if dl.base != nil {
|
||||
var m runtime.MemStats
|
||||
@@ -80,32 +71,27 @@ func (dl *DebugLogger) LogMemory() {
|
||||
}
|
||||
}
|
||||
|
||||
// LogGoroutines logs current number of goroutines
|
||||
func (dl *DebugLogger) LogGoroutines() {
|
||||
if dl.base != nil {
|
||||
dl.base.Log(LogLevelDebug, "GOROUTINES: %d active", runtime.NumGoroutine())
|
||||
}
|
||||
}
|
||||
|
||||
// LogTiming logs timing information for performance debugging
|
||||
func (dl *DebugLogger) LogTiming(operation string, duration interface{}) {
|
||||
if dl.base != nil {
|
||||
dl.base.Log(LogLevelDebug, "TIMING [%s]: %v", operation, duration)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to convert bytes to kilobytes
|
||||
func bToKb(b uint64) uint64 {
|
||||
return b / 1024
|
||||
}
|
||||
|
||||
// Global debug logger instance
|
||||
var (
|
||||
debugLogger *DebugLogger
|
||||
debugOnce sync.Once
|
||||
)
|
||||
|
||||
// GetDebugLogger returns the global debug logger instance
|
||||
func GetDebugLogger() *DebugLogger {
|
||||
debugOnce.Do(func() {
|
||||
debugLogger = NewDebugLogger()
|
||||
@@ -113,47 +99,38 @@ func GetDebugLogger() *DebugLogger {
|
||||
return debugLogger
|
||||
}
|
||||
|
||||
// Debug logs a debug-level message using the global debug logger
|
||||
func Debug(format string, v ...interface{}) {
|
||||
GetDebugLogger().Log(format, v...)
|
||||
}
|
||||
|
||||
// DebugWithContext logs a debug-level message with context using the global debug logger
|
||||
func DebugWithContext(context string, format string, v ...interface{}) {
|
||||
GetDebugLogger().LogWithContext(context, format, v...)
|
||||
}
|
||||
|
||||
// DebugFunction logs function entry and exit using the global debug logger
|
||||
func DebugFunction(functionName string, args ...interface{}) {
|
||||
GetDebugLogger().LogFunction(functionName, args...)
|
||||
}
|
||||
|
||||
// DebugVariable logs variable values using the global debug logger
|
||||
func DebugVariable(varName string, value interface{}) {
|
||||
GetDebugLogger().LogVariable(varName, value)
|
||||
}
|
||||
|
||||
// DebugState logs application state information using the global debug logger
|
||||
func DebugState(component string, state interface{}) {
|
||||
GetDebugLogger().LogState(component, state)
|
||||
}
|
||||
|
||||
// DebugSQL logs SQL queries using the global debug logger
|
||||
func DebugSQL(query string, args ...interface{}) {
|
||||
GetDebugLogger().LogSQL(query, args...)
|
||||
}
|
||||
|
||||
// DebugMemory logs memory usage information using the global debug logger
|
||||
func DebugMemory() {
|
||||
GetDebugLogger().LogMemory()
|
||||
}
|
||||
|
||||
// DebugGoroutines logs current number of goroutines using the global debug logger
|
||||
func DebugGoroutines() {
|
||||
GetDebugLogger().LogGoroutines()
|
||||
}
|
||||
|
||||
// DebugTiming logs timing information using the global debug logger
|
||||
func DebugTiming(operation string, duration interface{}) {
|
||||
GetDebugLogger().LogTiming(operation, duration)
|
||||
}
|
||||
|
||||
@@ -6,12 +6,10 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ErrorLogger handles error-level logging
|
||||
type ErrorLogger struct {
|
||||
base *BaseLogger
|
||||
}
|
||||
|
||||
// NewErrorLogger creates a new error logger instance
|
||||
func NewErrorLogger() *ErrorLogger {
|
||||
base, _ := InitializeBase("error")
|
||||
return &ErrorLogger{
|
||||
@@ -19,14 +17,12 @@ func NewErrorLogger() *ErrorLogger {
|
||||
}
|
||||
}
|
||||
|
||||
// Log writes an error-level log entry
|
||||
func (el *ErrorLogger) Log(format string, v ...interface{}) {
|
||||
if el.base != nil {
|
||||
el.base.Log(LogLevelError, format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithContext writes an error-level log entry with additional context
|
||||
func (el *ErrorLogger) LogWithContext(context string, format string, v ...interface{}) {
|
||||
if el.base != nil {
|
||||
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
|
||||
@@ -34,7 +30,6 @@ func (el *ErrorLogger) LogWithContext(context string, format string, v ...interf
|
||||
}
|
||||
}
|
||||
|
||||
// LogError logs an error object with optional message
|
||||
func (el *ErrorLogger) LogError(err error, message ...string) {
|
||||
if el.base != nil && err != nil {
|
||||
if len(message) > 0 {
|
||||
@@ -45,7 +40,6 @@ func (el *ErrorLogger) LogError(err error, message ...string) {
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithStackTrace logs an error with stack trace
|
||||
func (el *ErrorLogger) LogWithStackTrace(format string, v ...interface{}) {
|
||||
if el.base != nil {
|
||||
// Get stack trace
|
||||
@@ -58,7 +52,6 @@ func (el *ErrorLogger) LogWithStackTrace(format string, v ...interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// LogFatal logs a fatal error and exits the program
|
||||
func (el *ErrorLogger) LogFatal(format string, v ...interface{}) {
|
||||
if el.base != nil {
|
||||
el.base.Log(LogLevelError, "[FATAL] "+format, v...)
|
||||
@@ -66,13 +59,11 @@ func (el *ErrorLogger) LogFatal(format string, v ...interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// Global error logger instance
|
||||
var (
|
||||
errorLogger *ErrorLogger
|
||||
errorOnce sync.Once
|
||||
)
|
||||
|
||||
// GetErrorLogger returns the global error logger instance
|
||||
func GetErrorLogger() *ErrorLogger {
|
||||
errorOnce.Do(func() {
|
||||
errorLogger = NewErrorLogger()
|
||||
@@ -80,27 +71,22 @@ func GetErrorLogger() *ErrorLogger {
|
||||
return errorLogger
|
||||
}
|
||||
|
||||
// Error logs an error-level message using the global error logger
|
||||
func Error(format string, v ...interface{}) {
|
||||
GetErrorLogger().Log(format, v...)
|
||||
}
|
||||
|
||||
// ErrorWithContext logs an error-level message with context using the global error logger
|
||||
func ErrorWithContext(context string, format string, v ...interface{}) {
|
||||
GetErrorLogger().LogWithContext(context, format, v...)
|
||||
}
|
||||
|
||||
// LogError logs an error object using the global error logger
|
||||
func LogError(err error, message ...string) {
|
||||
GetErrorLogger().LogError(err, message...)
|
||||
}
|
||||
|
||||
// ErrorWithStackTrace logs an error with stack trace using the global error logger
|
||||
func ErrorWithStackTrace(format string, v ...interface{}) {
|
||||
GetErrorLogger().LogWithStackTrace(format, v...)
|
||||
}
|
||||
|
||||
// Fatal logs a fatal error and exits the program using the global error logger
|
||||
func Fatal(format string, v ...interface{}) {
|
||||
GetErrorLogger().LogFatal(format, v...)
|
||||
}
|
||||
|
||||
@@ -5,12 +5,10 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// InfoLogger handles info-level logging
|
||||
type InfoLogger struct {
|
||||
base *BaseLogger
|
||||
}
|
||||
|
||||
// NewInfoLogger creates a new info logger instance
|
||||
func NewInfoLogger() *InfoLogger {
|
||||
base, _ := InitializeBase("info")
|
||||
return &InfoLogger{
|
||||
@@ -18,14 +16,12 @@ func NewInfoLogger() *InfoLogger {
|
||||
}
|
||||
}
|
||||
|
||||
// Log writes an info-level log entry
|
||||
func (il *InfoLogger) Log(format string, v ...interface{}) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithContext writes an info-level log entry with additional context
|
||||
func (il *InfoLogger) LogWithContext(context string, format string, v ...interface{}) {
|
||||
if il.base != nil {
|
||||
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
|
||||
@@ -33,55 +29,47 @@ func (il *InfoLogger) LogWithContext(context string, format string, v ...interfa
|
||||
}
|
||||
}
|
||||
|
||||
// LogStartup logs application startup information
|
||||
func (il *InfoLogger) LogStartup(component string, message string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "STARTUP [%s]: %s", component, message)
|
||||
}
|
||||
}
|
||||
|
||||
// LogShutdown logs application shutdown information
|
||||
func (il *InfoLogger) LogShutdown(component string, message string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "SHUTDOWN [%s]: %s", component, message)
|
||||
}
|
||||
}
|
||||
|
||||
// LogOperation logs general operation information
|
||||
func (il *InfoLogger) LogOperation(operation string, details string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "OPERATION [%s]: %s", operation, details)
|
||||
}
|
||||
}
|
||||
|
||||
// LogStatus logs status changes or updates
|
||||
func (il *InfoLogger) LogStatus(component string, status string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "STATUS [%s]: %s", component, status)
|
||||
}
|
||||
}
|
||||
|
||||
// LogRequest logs incoming requests
|
||||
func (il *InfoLogger) LogRequest(method string, path string, userAgent string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "REQUEST [%s %s] User-Agent: %s", method, path, userAgent)
|
||||
}
|
||||
}
|
||||
|
||||
// LogResponse logs outgoing responses
|
||||
func (il *InfoLogger) LogResponse(method string, path string, statusCode int, duration string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "RESPONSE [%s %s] Status: %d, Duration: %s", method, path, statusCode, duration)
|
||||
}
|
||||
}
|
||||
|
||||
// Global info logger instance
|
||||
var (
|
||||
infoLogger *InfoLogger
|
||||
infoOnce sync.Once
|
||||
)
|
||||
|
||||
// GetInfoLogger returns the global info logger instance
|
||||
func GetInfoLogger() *InfoLogger {
|
||||
infoOnce.Do(func() {
|
||||
infoLogger = NewInfoLogger()
|
||||
@@ -89,42 +77,34 @@ func GetInfoLogger() *InfoLogger {
|
||||
return infoLogger
|
||||
}
|
||||
|
||||
// Info logs an info-level message using the global info logger
|
||||
func Info(format string, v ...interface{}) {
|
||||
GetInfoLogger().Log(format, v...)
|
||||
}
|
||||
|
||||
// InfoWithContext logs an info-level message with context using the global info logger
|
||||
func InfoWithContext(context string, format string, v ...interface{}) {
|
||||
GetInfoLogger().LogWithContext(context, format, v...)
|
||||
}
|
||||
|
||||
// InfoStartup logs application startup information using the global info logger
|
||||
func InfoStartup(component string, message string) {
|
||||
GetInfoLogger().LogStartup(component, message)
|
||||
}
|
||||
|
||||
// InfoShutdown logs application shutdown information using the global info logger
|
||||
func InfoShutdown(component string, message string) {
|
||||
GetInfoLogger().LogShutdown(component, message)
|
||||
}
|
||||
|
||||
// InfoOperation logs general operation information using the global info logger
|
||||
func InfoOperation(operation string, details string) {
|
||||
GetInfoLogger().LogOperation(operation, details)
|
||||
}
|
||||
|
||||
// InfoStatus logs status changes or updates using the global info logger
|
||||
func InfoStatus(component string, status string) {
|
||||
GetInfoLogger().LogStatus(component, status)
|
||||
}
|
||||
|
||||
// InfoRequest logs incoming requests using the global info logger
|
||||
func InfoRequest(method string, path string, userAgent string) {
|
||||
GetInfoLogger().LogRequest(method, path, userAgent)
|
||||
}
|
||||
|
||||
// InfoResponse logs outgoing responses using the global info logger
|
||||
func InfoResponse(method string, path string, statusCode int, duration string) {
|
||||
GetInfoLogger().LogResponse(method, path, statusCode, duration)
|
||||
}
|
||||
|
||||
@@ -6,12 +6,10 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
// Legacy logger for backward compatibility
|
||||
logger *Logger
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
// Logger maintains backward compatibility with existing code
|
||||
type Logger struct {
|
||||
base *BaseLogger
|
||||
errorLogger *ErrorLogger
|
||||
@@ -20,8 +18,6 @@ type Logger struct {
|
||||
debugLogger *DebugLogger
|
||||
}
|
||||
|
||||
// Initialize creates or gets the singleton logger instance
|
||||
// This maintains backward compatibility with existing code
|
||||
func Initialize() (*Logger, error) {
|
||||
var err error
|
||||
once.Do(func() {
|
||||
@@ -31,13 +27,11 @@ func Initialize() (*Logger, error) {
|
||||
}
|
||||
|
||||
func newLogger() (*Logger, error) {
|
||||
// Initialize the base logger
|
||||
baseLogger, err := InitializeBase("log")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create the legacy logger wrapper
|
||||
logger := &Logger{
|
||||
base: baseLogger,
|
||||
errorLogger: GetErrorLogger(),
|
||||
@@ -49,7 +43,6 @@ func newLogger() (*Logger, error) {
|
||||
return logger, nil
|
||||
}
|
||||
|
||||
// Close closes the logger
|
||||
func (l *Logger) Close() error {
|
||||
if l.base != nil {
|
||||
return l.base.Close()
|
||||
@@ -57,7 +50,6 @@ func (l *Logger) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Legacy methods for backward compatibility
|
||||
func (l *Logger) log(level, format string, v ...interface{}) {
|
||||
if l.base != nil {
|
||||
l.base.LogWithCaller(LogLevel(level), 3, format, v...)
|
||||
@@ -94,13 +86,10 @@ func (l *Logger) Panic(format string) {
|
||||
}
|
||||
}
|
||||
|
||||
// Global convenience functions for backward compatibility
|
||||
// These are now implemented in individual logger files to avoid redeclaration
|
||||
func LegacyInfo(format string, v ...interface{}) {
|
||||
if logger != nil {
|
||||
logger.Info(format, v...)
|
||||
} else {
|
||||
// Fallback to direct logger if legacy logger not initialized
|
||||
GetInfoLogger().Log(format, v...)
|
||||
}
|
||||
}
|
||||
@@ -109,7 +98,6 @@ func LegacyError(format string, v ...interface{}) {
|
||||
if logger != nil {
|
||||
logger.Error(format, v...)
|
||||
} else {
|
||||
// Fallback to direct logger if legacy logger not initialized
|
||||
GetErrorLogger().Log(format, v...)
|
||||
}
|
||||
}
|
||||
@@ -118,7 +106,6 @@ func LegacyWarn(format string, v ...interface{}) {
|
||||
if logger != nil {
|
||||
logger.Warn(format, v...)
|
||||
} else {
|
||||
// Fallback to direct logger if legacy logger not initialized
|
||||
GetWarnLogger().Log(format, v...)
|
||||
}
|
||||
}
|
||||
@@ -127,7 +114,6 @@ func LegacyDebug(format string, v ...interface{}) {
|
||||
if logger != nil {
|
||||
logger.Debug(format, v...)
|
||||
} else {
|
||||
// Fallback to direct logger if legacy logger not initialized
|
||||
GetDebugLogger().Log(format, v...)
|
||||
}
|
||||
}
|
||||
@@ -136,55 +122,42 @@ func Panic(format string) {
|
||||
if logger != nil {
|
||||
logger.Panic(format)
|
||||
} else {
|
||||
// Fallback to direct logger if legacy logger not initialized
|
||||
GetErrorLogger().LogFatal(format)
|
||||
}
|
||||
}
|
||||
|
||||
// Enhanced logging convenience functions
|
||||
// These provide direct access to specialized logging functions
|
||||
|
||||
// LogStartup logs application startup information
|
||||
func LogStartup(component string, message string) {
|
||||
GetInfoLogger().LogStartup(component, message)
|
||||
}
|
||||
|
||||
// LogShutdown logs application shutdown information
|
||||
func LogShutdown(component string, message string) {
|
||||
GetInfoLogger().LogShutdown(component, message)
|
||||
}
|
||||
|
||||
// LogOperation logs general operation information
|
||||
func LogOperation(operation string, details string) {
|
||||
GetInfoLogger().LogOperation(operation, details)
|
||||
}
|
||||
|
||||
// LogRequest logs incoming HTTP requests
|
||||
func LogRequest(method string, path string, userAgent string) {
|
||||
GetInfoLogger().LogRequest(method, path, userAgent)
|
||||
}
|
||||
|
||||
// LogResponse logs outgoing HTTP responses
|
||||
func LogResponse(method string, path string, statusCode int, duration string) {
|
||||
GetInfoLogger().LogResponse(method, path, statusCode, duration)
|
||||
}
|
||||
|
||||
// LogSQL logs SQL queries for debugging
|
||||
func LogSQL(query string, args ...interface{}) {
|
||||
GetDebugLogger().LogSQL(query, args...)
|
||||
}
|
||||
|
||||
// LogMemory logs memory usage information
|
||||
func LogMemory() {
|
||||
GetDebugLogger().LogMemory()
|
||||
}
|
||||
|
||||
// LogTiming logs timing information for performance debugging
|
||||
func LogTiming(operation string, duration interface{}) {
|
||||
GetDebugLogger().LogTiming(operation, duration)
|
||||
}
|
||||
|
||||
// GetLegacyLogger returns the legacy logger instance for backward compatibility
|
||||
func GetLegacyLogger() *Logger {
|
||||
if logger == nil {
|
||||
logger, _ = Initialize()
|
||||
@@ -192,21 +165,17 @@ func GetLegacyLogger() *Logger {
|
||||
return logger
|
||||
}
|
||||
|
||||
// InitializeLogging initializes all logging components
|
||||
func InitializeLogging() error {
|
||||
// Initialize legacy logger for backward compatibility
|
||||
_, err := Initialize()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize legacy logger: %v", err)
|
||||
}
|
||||
|
||||
// Pre-initialize all logger types to ensure separate log files
|
||||
GetErrorLogger()
|
||||
GetWarnLogger()
|
||||
GetInfoLogger()
|
||||
GetDebugLogger()
|
||||
|
||||
// Log successful initialization
|
||||
Info("Logging system initialized successfully")
|
||||
|
||||
return nil
|
||||
|
||||
@@ -5,12 +5,10 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// WarnLogger handles warn-level logging
|
||||
type WarnLogger struct {
|
||||
base *BaseLogger
|
||||
}
|
||||
|
||||
// NewWarnLogger creates a new warn logger instance
|
||||
func NewWarnLogger() *WarnLogger {
|
||||
base, _ := InitializeBase("warn")
|
||||
return &WarnLogger{
|
||||
@@ -18,14 +16,12 @@ func NewWarnLogger() *WarnLogger {
|
||||
}
|
||||
}
|
||||
|
||||
// Log writes a warn-level log entry
|
||||
func (wl *WarnLogger) Log(format string, v ...interface{}) {
|
||||
if wl.base != nil {
|
||||
wl.base.Log(LogLevelWarn, format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithContext writes a warn-level log entry with additional context
|
||||
func (wl *WarnLogger) LogWithContext(context string, format string, v ...interface{}) {
|
||||
if wl.base != nil {
|
||||
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
|
||||
@@ -33,7 +29,6 @@ func (wl *WarnLogger) LogWithContext(context string, format string, v ...interfa
|
||||
}
|
||||
}
|
||||
|
||||
// LogDeprecation logs a deprecation warning
|
||||
func (wl *WarnLogger) LogDeprecation(feature string, alternative string) {
|
||||
if wl.base != nil {
|
||||
if alternative != "" {
|
||||
@@ -44,27 +39,23 @@ func (wl *WarnLogger) LogDeprecation(feature string, alternative string) {
|
||||
}
|
||||
}
|
||||
|
||||
// LogConfiguration logs configuration-related warnings
|
||||
func (wl *WarnLogger) LogConfiguration(setting string, message string) {
|
||||
if wl.base != nil {
|
||||
wl.base.Log(LogLevelWarn, "CONFIG WARNING [%s]: %s", setting, message)
|
||||
}
|
||||
}
|
||||
|
||||
// LogPerformance logs performance-related warnings
|
||||
func (wl *WarnLogger) LogPerformance(operation string, threshold string, actual string) {
|
||||
if wl.base != nil {
|
||||
wl.base.Log(LogLevelWarn, "PERFORMANCE WARNING [%s]: exceeded threshold %s, actual: %s", operation, threshold, actual)
|
||||
}
|
||||
}
|
||||
|
||||
// Global warn logger instance
|
||||
var (
|
||||
warnLogger *WarnLogger
|
||||
warnOnce sync.Once
|
||||
)
|
||||
|
||||
// GetWarnLogger returns the global warn logger instance
|
||||
func GetWarnLogger() *WarnLogger {
|
||||
warnOnce.Do(func() {
|
||||
warnLogger = NewWarnLogger()
|
||||
@@ -72,27 +63,22 @@ func GetWarnLogger() *WarnLogger {
|
||||
return warnLogger
|
||||
}
|
||||
|
||||
// Warn logs a warn-level message using the global warn logger
|
||||
func Warn(format string, v ...interface{}) {
|
||||
GetWarnLogger().Log(format, v...)
|
||||
}
|
||||
|
||||
// WarnWithContext logs a warn-level message with context using the global warn logger
|
||||
func WarnWithContext(context string, format string, v ...interface{}) {
|
||||
GetWarnLogger().LogWithContext(context, format, v...)
|
||||
}
|
||||
|
||||
// WarnDeprecation logs a deprecation warning using the global warn logger
|
||||
func WarnDeprecation(feature string, alternative string) {
|
||||
GetWarnLogger().LogDeprecation(feature, alternative)
|
||||
}
|
||||
|
||||
// WarnConfiguration logs configuration-related warnings using the global warn logger
|
||||
func WarnConfiguration(setting string, message string) {
|
||||
GetWarnLogger().LogConfiguration(setting, message)
|
||||
}
|
||||
|
||||
// WarnPerformance logs performance-related warnings using the global warn logger
|
||||
func WarnPerformance(operation string, threshold string, actual string) {
|
||||
GetWarnLogger().LogPerformance(operation, threshold, actual)
|
||||
}
|
||||
|
||||
@@ -6,12 +6,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// IsPortAvailable checks if a port is available for both TCP and UDP
|
||||
func IsPortAvailable(port int) bool {
|
||||
return IsTCPPortAvailable(port) && IsUDPPortAvailable(port)
|
||||
}
|
||||
|
||||
// IsTCPPortAvailable checks if a TCP port is available
|
||||
func IsTCPPortAvailable(port int) bool {
|
||||
addr := fmt.Sprintf(":%d", port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
@@ -22,7 +20,6 @@ func IsTCPPortAvailable(port int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsUDPPortAvailable checks if a UDP port is available
|
||||
func IsUDPPortAvailable(port int) bool {
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
|
||||
if err != nil {
|
||||
@@ -32,7 +29,6 @@ func IsUDPPortAvailable(port int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// FindAvailablePort finds an available port starting from the given port
|
||||
func FindAvailablePort(startPort int) (int, error) {
|
||||
maxPort := 65535
|
||||
for port := startPort; port <= maxPort; port++ {
|
||||
@@ -43,14 +39,12 @@ func FindAvailablePort(startPort int) (int, error) {
|
||||
return 0, fmt.Errorf("no available ports found between %d and %d", startPort, maxPort)
|
||||
}
|
||||
|
||||
// FindAvailablePortRange finds a range of consecutive available ports
|
||||
func FindAvailablePortRange(startPort, count int) ([]int, error) {
|
||||
maxPort := 65535
|
||||
ports := make([]int, 0, count)
|
||||
currentPort := startPort
|
||||
|
||||
for len(ports) < count && currentPort <= maxPort {
|
||||
// Check if we have enough consecutive ports available
|
||||
available := true
|
||||
for i := 0; i < count-len(ports); i++ {
|
||||
if !IsPortAvailable(currentPort + i) {
|
||||
@@ -74,7 +68,6 @@ func FindAvailablePortRange(startPort, count int) ([]int, error) {
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
// WaitForPortAvailable waits for a port to become available with timeout
|
||||
func WaitForPortAvailable(port int, timeout time.Duration) error {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
@@ -84,4 +77,4 @@ func WaitForPortAvailable(port int, timeout time.Duration) error {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
return fmt.Errorf("timeout waiting for port %d to become available", port)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,13 +8,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// MinPasswordLength defines the minimum password length
|
||||
MinPasswordLength = 8
|
||||
// BcryptCost defines the cost factor for bcrypt hashing
|
||||
BcryptCost = 12
|
||||
BcryptCost = 12
|
||||
)
|
||||
|
||||
// HashPassword hashes a plain text password using bcrypt
|
||||
func HashPassword(password string) (string, error) {
|
||||
if len(password) < MinPasswordLength {
|
||||
return "", errors.New("password must be at least 8 characters long")
|
||||
@@ -28,12 +25,10 @@ func HashPassword(password string) (string, error) {
|
||||
return string(hashedBytes), nil
|
||||
}
|
||||
|
||||
// VerifyPassword verifies a plain text password against a hashed password
|
||||
func VerifyPassword(hashedPassword, password string) error {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
}
|
||||
|
||||
// ValidatePasswordStrength validates password complexity requirements
|
||||
func ValidatePasswordStrength(password string) error {
|
||||
if len(password) < MinPasswordLength {
|
||||
return errors.New("password must be at least 8 characters long")
|
||||
|
||||
76
local/utl/security/download_verifier.go
Normal file
76
local/utl/security/download_verifier.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DownloadVerifier struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewDownloadVerifier() *DownloadVerifier {
|
||||
return &DownloadVerifier{
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (dv *DownloadVerifier) VerifyAndDownload(url, outputPath, expectedSHA256 string) error {
|
||||
if url == "" {
|
||||
return fmt.Errorf("URL cannot be empty")
|
||||
}
|
||||
if outputPath == "" {
|
||||
return fmt.Errorf("output path cannot be empty")
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "ACC-Server-Manager/1.0")
|
||||
|
||||
resp, err := dv.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
file, err := os.Create(outputPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create output file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
hash := sha256.New()
|
||||
writer := io.MultiWriter(file, hash)
|
||||
|
||||
_, err = io.Copy(writer, resp.Body)
|
||||
if err != nil {
|
||||
os.Remove(outputPath)
|
||||
return fmt.Errorf("failed to write file: %v", err)
|
||||
}
|
||||
|
||||
if expectedSHA256 != "" {
|
||||
actualHash := fmt.Sprintf("%x", hash.Sum(nil))
|
||||
if actualHash != expectedSHA256 {
|
||||
os.Remove(outputPath)
|
||||
return fmt.Errorf("file hash mismatch: expected %s, got %s", expectedSHA256, actualHash)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
94
local/utl/security/path_validator.go
Normal file
94
local/utl/security/path_validator.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type PathValidator struct {
|
||||
allowedBasePaths []string
|
||||
blockedPatterns []*regexp.Regexp
|
||||
}
|
||||
|
||||
func NewPathValidator() *PathValidator {
|
||||
blockedPatterns := []*regexp.Regexp{
|
||||
regexp.MustCompile(`\.\.`),
|
||||
regexp.MustCompile(`^(CON|PRN|AUX|NUL|COM[1-9]|LPT[1-9])$`),
|
||||
regexp.MustCompile(`\x00`),
|
||||
regexp.MustCompile(`^\\\\`),
|
||||
regexp.MustCompile(`^[a-zA-Z]:\\Windows`),
|
||||
regexp.MustCompile(`^[a-zA-Z]:\\Program Files`),
|
||||
}
|
||||
|
||||
return &PathValidator{
|
||||
allowedBasePaths: []string{
|
||||
`C:\ACC-Servers`,
|
||||
`D:\ACC-Servers`,
|
||||
`E:\ACC-Servers`,
|
||||
`C:\SteamCMD`,
|
||||
`D:\SteamCMD`,
|
||||
`E:\SteamCMD`,
|
||||
},
|
||||
blockedPatterns: blockedPatterns,
|
||||
}
|
||||
}
|
||||
|
||||
func (pv *PathValidator) ValidateInstallPath(path string) error {
|
||||
if path == "" {
|
||||
return fmt.Errorf("path cannot be empty")
|
||||
}
|
||||
|
||||
cleanPath := filepath.Clean(path)
|
||||
absPath, err := filepath.Abs(cleanPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid path: %v", err)
|
||||
}
|
||||
|
||||
for _, pattern := range pv.blockedPatterns {
|
||||
if pattern.MatchString(absPath) || pattern.MatchString(strings.ToUpper(filepath.Base(absPath))) {
|
||||
return fmt.Errorf("path contains forbidden patterns")
|
||||
}
|
||||
}
|
||||
|
||||
allowed := false
|
||||
for _, basePath := range pv.allowedBasePaths {
|
||||
if strings.HasPrefix(strings.ToLower(absPath), strings.ToLower(basePath)) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return fmt.Errorf("path must be within allowed directories: %v", pv.allowedBasePaths)
|
||||
}
|
||||
|
||||
if len(absPath) > 260 {
|
||||
return fmt.Errorf("path too long (max 260 characters)")
|
||||
}
|
||||
|
||||
parentDir := filepath.Dir(absPath)
|
||||
if parentInfo, err := os.Stat(parentDir); err == nil {
|
||||
if !parentInfo.IsDir() {
|
||||
return fmt.Errorf("parent path is not a directory")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pv *PathValidator) AddAllowedBasePath(path string) error {
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid base path: %v", err)
|
||||
}
|
||||
|
||||
pv.allowedBasePaths = append(pv.allowedBasePaths, absPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pv *PathValidator) GetAllowedBasePaths() []string {
|
||||
return append([]string(nil), pv.allowedBasePaths...)
|
||||
}
|
||||
@@ -17,30 +17,29 @@ import (
|
||||
func Start(di *dig.Container) *fiber.App {
|
||||
app := fiber.New(fiber.Config{
|
||||
EnablePrintRoutes: true,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
BodyLimit: 10 * 1024 * 1024, // 10MB
|
||||
ReadTimeout: 20 * time.Minute,
|
||||
WriteTimeout: 20 * time.Minute,
|
||||
IdleTimeout: 25 * time.Minute,
|
||||
BodyLimit: 10 * 1024 * 1024,
|
||||
})
|
||||
|
||||
// Initialize security middleware
|
||||
securityMW := security.NewSecurityMiddleware()
|
||||
|
||||
// Add security middleware stack
|
||||
app.Use(securityMW.SecurityHeaders())
|
||||
app.Use(securityMW.LogSecurityEvents())
|
||||
app.Use(securityMW.TimeoutMiddleware(30 * time.Second))
|
||||
app.Use(securityMW.RequestSizeLimit(10 * 1024 * 1024)) // 10MB
|
||||
app.Use(securityMW.TimeoutMiddleware(20 * time.Minute))
|
||||
app.Use(securityMW.RequestContextTimeout(20 * time.Minute))
|
||||
app.Use(securityMW.RequestSizeLimit(10 * 1024 * 1024))
|
||||
app.Use(securityMW.ValidateUserAgent())
|
||||
app.Use(securityMW.ValidateContentType("application/json", "application/x-www-form-urlencoded", "multipart/form-data"))
|
||||
app.Use(securityMW.InputSanitization())
|
||||
app.Use(securityMW.RateLimit(100, 1*time.Minute)) // 100 requests per minute global
|
||||
app.Use(securityMW.RateLimit(100, 1*time.Minute))
|
||||
|
||||
app.Use(helmet.New())
|
||||
|
||||
allowedOrigin := os.Getenv("CORS_ALLOWED_ORIGIN")
|
||||
if allowedOrigin == "" {
|
||||
allowedOrigin = "http://localhost:5173"
|
||||
allowedOrigin = "http://localhost:3000"
|
||||
}
|
||||
|
||||
app.Use(cors.New(cors.Config{
|
||||
@@ -61,7 +60,7 @@ func Start(di *dig.Container) *fiber.App {
|
||||
|
||||
port := os.Getenv("PORT")
|
||||
if port == "" {
|
||||
port = "3000" // Default port
|
||||
port = "3000"
|
||||
}
|
||||
|
||||
logging.Info("Starting server on port %s", port)
|
||||
|
||||
@@ -7,11 +7,11 @@ import (
|
||||
)
|
||||
|
||||
type LogTailer struct {
|
||||
filePath string
|
||||
handleLine func(string)
|
||||
stopChan chan struct{}
|
||||
isRunning bool
|
||||
tracker *PositionTracker
|
||||
filePath string
|
||||
handleLine func(string)
|
||||
stopChan chan struct{}
|
||||
isRunning bool
|
||||
tracker *PositionTracker
|
||||
}
|
||||
|
||||
func NewLogTailer(filePath string, handleLine func(string)) *LogTailer {
|
||||
@@ -30,10 +30,9 @@ func (t *LogTailer) Start() {
|
||||
t.isRunning = true
|
||||
|
||||
go func() {
|
||||
// Load last position from tracker
|
||||
pos, err := t.tracker.LoadPosition()
|
||||
if err != nil {
|
||||
pos = &LogPosition{} // Start from beginning if error
|
||||
pos = &LogPosition{}
|
||||
}
|
||||
lastSize := pos.LastPosition
|
||||
|
||||
@@ -43,7 +42,6 @@ func (t *LogTailer) Start() {
|
||||
t.isRunning = false
|
||||
return
|
||||
default:
|
||||
// Try to open and read the file
|
||||
if file, err := os.Open(t.filePath); err == nil {
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
@@ -52,12 +50,10 @@ func (t *LogTailer) Start() {
|
||||
continue
|
||||
}
|
||||
|
||||
// If file was truncated, start from beginning
|
||||
if stat.Size() < lastSize {
|
||||
lastSize = 0
|
||||
}
|
||||
|
||||
// Seek to last read position
|
||||
if lastSize > 0 {
|
||||
file.Seek(lastSize, 0)
|
||||
}
|
||||
@@ -66,9 +62,8 @@ func (t *LogTailer) Start() {
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
t.handleLine(line)
|
||||
lastSize, _ = file.Seek(0, 1) // Get current position
|
||||
|
||||
// Save position periodically
|
||||
lastSize, _ = file.Seek(0, 1)
|
||||
|
||||
t.tracker.SavePosition(&LogPosition{
|
||||
LastPosition: lastSize,
|
||||
LastRead: line,
|
||||
@@ -78,7 +73,6 @@ func (t *LogTailer) Start() {
|
||||
file.Close()
|
||||
}
|
||||
|
||||
// Wait before next attempt
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
@@ -90,4 +84,4 @@ func (t *LogTailer) Stop() {
|
||||
return
|
||||
}
|
||||
close(t.stopChan)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
)
|
||||
|
||||
type LogPosition struct {
|
||||
LastPosition int64 `json:"last_position"`
|
||||
LastRead string `json:"last_read"`
|
||||
LastPosition int64 `json:"last_position"`
|
||||
LastRead string `json:"last_read"`
|
||||
}
|
||||
|
||||
type PositionTracker struct {
|
||||
@@ -16,11 +16,10 @@ type PositionTracker struct {
|
||||
}
|
||||
|
||||
func NewPositionTracker(logPath string) *PositionTracker {
|
||||
// Create position file in same directory as log file
|
||||
dir := filepath.Dir(logPath)
|
||||
base := filepath.Base(logPath)
|
||||
positionFile := filepath.Join(dir, "."+base+".position")
|
||||
|
||||
|
||||
return &PositionTracker{
|
||||
positionFile: positionFile,
|
||||
}
|
||||
@@ -30,7 +29,6 @@ func (t *PositionTracker) LoadPosition() (*LogPosition, error) {
|
||||
data, err := os.ReadFile(t.positionFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// Return empty position if file doesn't exist
|
||||
return &LogPosition{}, nil
|
||||
}
|
||||
return nil, err
|
||||
@@ -51,4 +49,4 @@ func (t *PositionTracker) SavePosition(pos *LogPosition) error {
|
||||
}
|
||||
|
||||
return os.WriteFile(t.positionFile, data, 0644)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,102 +11,104 @@ import (
|
||||
)
|
||||
|
||||
type StateChange int
|
||||
|
||||
const (
|
||||
PlayerCount StateChange = iota
|
||||
Session
|
||||
PlayerCount StateChange = iota
|
||||
Session
|
||||
)
|
||||
|
||||
var StateChanges = map[StateChange]string {
|
||||
PlayerCount: "player-count",
|
||||
Session: "session",
|
||||
var StateChanges = map[StateChange]string{
|
||||
PlayerCount: "player-count",
|
||||
Session: "session",
|
||||
}
|
||||
|
||||
type AccServerInstance struct {
|
||||
Model *model.Server
|
||||
State *model.ServerState
|
||||
OnStateChange func(*model.ServerState, ...StateChange)
|
||||
Model *model.Server
|
||||
State *model.ServerState
|
||||
OnStateChange func(*model.ServerState, ...StateChange)
|
||||
}
|
||||
|
||||
func NewAccServerInstance(server *model.Server, onStateChange func(*model.ServerState, ...StateChange)) *AccServerInstance {
|
||||
return &AccServerInstance{
|
||||
Model: server,
|
||||
State: &model.ServerState{PlayerCount: 0},
|
||||
OnStateChange: onStateChange,
|
||||
}
|
||||
return &AccServerInstance{
|
||||
Model: server,
|
||||
State: &model.ServerState{PlayerCount: 0},
|
||||
OnStateChange: onStateChange,
|
||||
}
|
||||
}
|
||||
|
||||
type StateRegexHandler struct {
|
||||
*regex_handler.RegexHandler
|
||||
test string
|
||||
*regex_handler.RegexHandler
|
||||
test string
|
||||
}
|
||||
|
||||
func NewRegexHandler(str string, test string) *StateRegexHandler {
|
||||
return &StateRegexHandler{
|
||||
RegexHandler: regex_handler.New(str),
|
||||
test: test,
|
||||
RegexHandler: regex_handler.New(str),
|
||||
test: test,
|
||||
}
|
||||
}
|
||||
|
||||
func (rh *StateRegexHandler) Test(line string) bool{
|
||||
return strings.Contains(line, rh.test)
|
||||
func (rh *StateRegexHandler) Test(line string) bool {
|
||||
return strings.Contains(line, rh.test)
|
||||
}
|
||||
|
||||
func (rh *StateRegexHandler) Count(line string) int{
|
||||
var count int = 0
|
||||
rh.Contains(line, func (strs ...string) {
|
||||
if len(strs) == 2 {
|
||||
if ct, err := strconv.Atoi(strs[1]); err == nil {
|
||||
count = ct
|
||||
}
|
||||
}
|
||||
})
|
||||
return count
|
||||
func (rh *StateRegexHandler) Count(line string) int {
|
||||
var count int = 0
|
||||
rh.Contains(line, func(strs ...string) {
|
||||
if len(strs) == 2 {
|
||||
if ct, err := strconv.Atoi(strs[1]); err == nil {
|
||||
count = ct
|
||||
}
|
||||
}
|
||||
})
|
||||
return count
|
||||
}
|
||||
|
||||
func (rh *StateRegexHandler) Change(line string) (string, string){
|
||||
var old string = ""
|
||||
var new string = ""
|
||||
rh.Contains(line, func (strs ...string) {
|
||||
if len(strs) == 3 {
|
||||
old = strs[1]
|
||||
new = strs[2]
|
||||
}
|
||||
})
|
||||
return old, new
|
||||
func (rh *StateRegexHandler) Change(line string) (string, string) {
|
||||
var old string = ""
|
||||
var new string = ""
|
||||
rh.Contains(line, func(strs ...string) {
|
||||
if len(strs) == 3 {
|
||||
old = strs[1]
|
||||
new = strs[2]
|
||||
}
|
||||
})
|
||||
return old, new
|
||||
}
|
||||
|
||||
func TailLogFile(path string, callback func(string)) {
|
||||
file, _ := os.Open(path)
|
||||
defer file.Close()
|
||||
file, _ := os.Open(path)
|
||||
defer file.Close()
|
||||
|
||||
file.Seek(0, os.SEEK_END) // Start at end of file
|
||||
reader := bufio.NewReader(file)
|
||||
file.Seek(0, os.SEEK_END)
|
||||
reader := bufio.NewReader(file)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err == nil {
|
||||
callback(line)
|
||||
} else {
|
||||
time.Sleep(500 * time.Millisecond) // wait for new data
|
||||
}
|
||||
}
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err == nil {
|
||||
callback(line)
|
||||
} else {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type LogStateType int
|
||||
|
||||
const (
|
||||
SessionChange LogStateType = iota
|
||||
LeaderboardUpdate
|
||||
UDPCount
|
||||
ClientsOnline
|
||||
RemovingDeadConnection
|
||||
SessionChange LogStateType = iota
|
||||
LeaderboardUpdate
|
||||
UDPCount
|
||||
ClientsOnline
|
||||
RemovingDeadConnection
|
||||
)
|
||||
|
||||
var logStateContain = map[LogStateType]string {
|
||||
SessionChange: "Session changed",
|
||||
LeaderboardUpdate: "Updated leaderboard for",
|
||||
UDPCount: "Udp message count",
|
||||
ClientsOnline: "client(s) online",
|
||||
RemovingDeadConnection: "Removing dead connection",
|
||||
var logStateContain = map[LogStateType]string{
|
||||
SessionChange: "Session changed",
|
||||
LeaderboardUpdate: "Updated leaderboard for",
|
||||
UDPCount: "Udp message count",
|
||||
ClientsOnline: "client(s) online",
|
||||
RemovingDeadConnection: "Removing dead connection",
|
||||
}
|
||||
|
||||
var sessionChangeRegex = NewRegexHandler(`Session changed: (\w+) -> (\w+)`, logStateContain[SessionChange])
|
||||
@@ -115,75 +117,77 @@ var udpCountRegex = NewRegexHandler(`Udp message count (\d+) client`, logStateCo
|
||||
var clientsOnlineRegex = NewRegexHandler(`(\d+) client\(s\) online`, logStateContain[ClientsOnline])
|
||||
var removingDeadConnectionsRegex = NewRegexHandler(`Removing dead connection`, logStateContain[RemovingDeadConnection])
|
||||
|
||||
var logStateRegex = map[LogStateType]*StateRegexHandler {
|
||||
SessionChange: sessionChangeRegex,
|
||||
LeaderboardUpdate: leaderboardUpdateRegex,
|
||||
UDPCount: udpCountRegex,
|
||||
ClientsOnline: clientsOnlineRegex,
|
||||
RemovingDeadConnection: removingDeadConnectionsRegex,
|
||||
var logStateRegex = map[LogStateType]*StateRegexHandler{
|
||||
SessionChange: sessionChangeRegex,
|
||||
LeaderboardUpdate: leaderboardUpdateRegex,
|
||||
UDPCount: udpCountRegex,
|
||||
ClientsOnline: clientsOnlineRegex,
|
||||
RemovingDeadConnection: removingDeadConnectionsRegex,
|
||||
}
|
||||
|
||||
func (instance *AccServerInstance) HandleLogLine(line string) {
|
||||
for logState, regexHandler := range logStateRegex {
|
||||
if (regexHandler.Test(line)) {
|
||||
switch logState {
|
||||
case LeaderboardUpdate:
|
||||
case UDPCount:
|
||||
case ClientsOnline:
|
||||
count := regexHandler.Count(line)
|
||||
instance.UpdatePlayerCount(count)
|
||||
case SessionChange:
|
||||
_, new := regexHandler.Change(line)
|
||||
instance.UpdateSessionChange(new)
|
||||
case RemovingDeadConnection:
|
||||
instance.UpdatePlayerCount(instance.State.PlayerCount - 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
for logState, regexHandler := range logStateRegex {
|
||||
if regexHandler.Test(line) {
|
||||
switch logState {
|
||||
case LeaderboardUpdate:
|
||||
case UDPCount:
|
||||
case ClientsOnline:
|
||||
count := regexHandler.Count(line)
|
||||
instance.UpdatePlayerCount(count)
|
||||
case SessionChange:
|
||||
_, new := regexHandler.Change(line)
|
||||
|
||||
trackSession := model.ToTrackSession(new)
|
||||
instance.UpdateSessionChange(trackSession)
|
||||
case RemovingDeadConnection:
|
||||
instance.UpdatePlayerCount(instance.State.PlayerCount - 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (instance *AccServerInstance) UpdateState(callback func(state *model.ServerState, changes *[]StateChange)) {
|
||||
state := instance.State
|
||||
changes := []StateChange{}
|
||||
state.Lock()
|
||||
defer state.Unlock()
|
||||
callback(state, &changes)
|
||||
if (len(changes) > 0) {
|
||||
instance.OnStateChange(state, changes...)
|
||||
}
|
||||
state := instance.State
|
||||
changes := []StateChange{}
|
||||
state.Lock()
|
||||
defer state.Unlock()
|
||||
callback(state, &changes)
|
||||
if len(changes) > 0 {
|
||||
instance.OnStateChange(state, changes...)
|
||||
}
|
||||
}
|
||||
|
||||
func (instance *AccServerInstance) UpdatePlayerCount(count int) {
|
||||
if (count < 0) {
|
||||
return
|
||||
}
|
||||
instance.UpdateState(func (state *model.ServerState, changes *[]StateChange) {
|
||||
if (count == state.PlayerCount) {
|
||||
return
|
||||
}
|
||||
if (count > 0 && state.PlayerCount == 0) {
|
||||
state.SessionStart = time.Now()
|
||||
*changes = append(*changes, Session)
|
||||
} else if (count == 0) {
|
||||
state.SessionStart = time.Time{}
|
||||
*changes = append(*changes, Session)
|
||||
}
|
||||
state.PlayerCount = count
|
||||
*changes = append(*changes, PlayerCount)
|
||||
})
|
||||
if count < 0 {
|
||||
return
|
||||
}
|
||||
instance.UpdateState(func(state *model.ServerState, changes *[]StateChange) {
|
||||
if count == state.PlayerCount {
|
||||
return
|
||||
}
|
||||
if count > 0 && state.PlayerCount == 0 {
|
||||
state.SessionStart = time.Now()
|
||||
*changes = append(*changes, Session)
|
||||
} else if count == 0 {
|
||||
state.SessionStart = time.Time{}
|
||||
*changes = append(*changes, Session)
|
||||
}
|
||||
state.PlayerCount = count
|
||||
*changes = append(*changes, PlayerCount)
|
||||
})
|
||||
}
|
||||
|
||||
func (instance *AccServerInstance) UpdateSessionChange(session string) {
|
||||
instance.UpdateState(func (state *model.ServerState, changes *[]StateChange) {
|
||||
if (session == state.Session) {
|
||||
return
|
||||
}
|
||||
if (state.PlayerCount > 0) {
|
||||
state.SessionStart = time.Now()
|
||||
} else {
|
||||
state.SessionStart = time.Time{}
|
||||
}
|
||||
state.Session = session
|
||||
*changes = append(*changes, Session)
|
||||
})
|
||||
}
|
||||
func (instance *AccServerInstance) UpdateSessionChange(session model.TrackSession) {
|
||||
instance.UpdateState(func(state *model.ServerState, changes *[]StateChange) {
|
||||
if session == state.Session {
|
||||
return
|
||||
}
|
||||
if state.PlayerCount > 0 {
|
||||
state.SessionStart = time.Now()
|
||||
} else {
|
||||
state.SessionStart = time.Time{}
|
||||
}
|
||||
state.Session = session
|
||||
*changes = append(*changes, Session)
|
||||
})
|
||||
}
|
||||
|
||||
169
local/utl/websocket/websocket.go
Normal file
169
local/utl/websocket/websocket.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"acc-server-manager/local/model"
|
||||
"acc-server-manager/local/utl/logging"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type WebSocketConnection struct {
|
||||
conn *websocket.Conn
|
||||
serverID *uuid.UUID
|
||||
userID *uuid.UUID
|
||||
}
|
||||
|
||||
type WebSocketService struct {
|
||||
connections sync.Map
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewWebSocketService() *WebSocketService {
|
||||
return &WebSocketService{}
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) AddConnection(connID string, conn *websocket.Conn, userID *uuid.UUID) {
|
||||
wsConn := &WebSocketConnection{
|
||||
conn: conn,
|
||||
userID: userID,
|
||||
}
|
||||
ws.connections.Store(connID, wsConn)
|
||||
logging.Info("WebSocket connection added: %s for user: %v", connID, userID)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) RemoveConnection(connID string) {
|
||||
if conn, exists := ws.connections.LoadAndDelete(connID); exists {
|
||||
if wsConn, ok := conn.(*WebSocketConnection); ok {
|
||||
wsConn.conn.Close()
|
||||
}
|
||||
}
|
||||
logging.Info("WebSocket connection removed: %s", connID)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) SetServerID(connID string, serverID uuid.UUID) {
|
||||
if conn, exists := ws.connections.Load(connID); exists {
|
||||
if wsConn, ok := conn.(*WebSocketConnection); ok {
|
||||
wsConn.serverID = &serverID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) BroadcastStep(serverID uuid.UUID, step model.ServerCreationStep, status model.StepStatus, message string, errorMsg string) {
|
||||
stepMsg := model.StepMessage{
|
||||
Step: step,
|
||||
Status: status,
|
||||
Message: message,
|
||||
Error: errorMsg,
|
||||
}
|
||||
|
||||
wsMsg := model.WebSocketMessage{
|
||||
Type: model.MessageTypeStep,
|
||||
ServerID: &serverID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Data: stepMsg,
|
||||
}
|
||||
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) BroadcastSteamOutput(serverID uuid.UUID, output string, isError bool) {
|
||||
steamMsg := model.SteamOutputMessage{
|
||||
Output: output,
|
||||
IsError: isError,
|
||||
}
|
||||
|
||||
wsMsg := model.WebSocketMessage{
|
||||
Type: model.MessageTypeSteamOutput,
|
||||
ServerID: &serverID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Data: steamMsg,
|
||||
}
|
||||
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) BroadcastError(serverID uuid.UUID, error string, details string) {
|
||||
errorMsg := model.ErrorMessage{
|
||||
Error: error,
|
||||
Details: details,
|
||||
}
|
||||
|
||||
wsMsg := model.WebSocketMessage{
|
||||
Type: model.MessageTypeError,
|
||||
ServerID: &serverID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Data: errorMsg,
|
||||
}
|
||||
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) BroadcastComplete(serverID uuid.UUID, success bool, message string) {
|
||||
completeMsg := model.CompleteMessage{
|
||||
ServerID: serverID,
|
||||
Success: success,
|
||||
Message: message,
|
||||
}
|
||||
|
||||
wsMsg := model.WebSocketMessage{
|
||||
Type: model.MessageTypeComplete,
|
||||
ServerID: &serverID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Data: completeMsg,
|
||||
}
|
||||
|
||||
ws.broadcastToServer(serverID, wsMsg)
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) broadcastToServer(serverID uuid.UUID, message model.WebSocketMessage) {
|
||||
data, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
logging.Error("Failed to marshal WebSocket message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ws.connections.Range(func(key, value interface{}) bool {
|
||||
if wsConn, ok := value.(*WebSocketConnection); ok {
|
||||
if wsConn.serverID != nil && *wsConn.serverID == serverID {
|
||||
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
|
||||
ws.RemoveConnection(key.(string))
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) BroadcastToUser(userID uuid.UUID, message model.WebSocketMessage) {
|
||||
data, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
logging.Error("Failed to marshal WebSocket message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ws.connections.Range(func(key, value interface{}) bool {
|
||||
if wsConn, ok := value.(*WebSocketConnection); ok {
|
||||
if wsConn.userID != nil && *wsConn.userID == userID {
|
||||
if err := wsConn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
logging.Error("Failed to send WebSocket message to connection %s: %v", key, err)
|
||||
ws.RemoveConnection(key.(string))
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (ws *WebSocketService) GetActiveConnections() int {
|
||||
count := 0
|
||||
ws.connections.Range(func(key, value interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
1186
swagger/docs.go
1186
swagger/docs.go
File diff suppressed because it is too large
Load Diff
1183
swagger/swagger.json
1183
swagger/swagger.json
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
basePath: /api/v1
|
||||
basePath: /v1
|
||||
definitions:
|
||||
error_handler.ErrorResponse:
|
||||
properties:
|
||||
@@ -92,6 +92,35 @@ definitions:
|
||||
- StatusRestarting
|
||||
- StatusStarting
|
||||
- StatusRunning
|
||||
model.Steam2FARequest:
|
||||
properties:
|
||||
completedAt:
|
||||
type: string
|
||||
errorMsg:
|
||||
type: string
|
||||
id:
|
||||
type: string
|
||||
message:
|
||||
type: string
|
||||
requestTime:
|
||||
type: string
|
||||
serverId:
|
||||
type: string
|
||||
status:
|
||||
$ref: '#/definitions/model.Steam2FAStatus'
|
||||
type: object
|
||||
model.Steam2FAStatus:
|
||||
enum:
|
||||
- idle
|
||||
- pending
|
||||
- complete
|
||||
- error
|
||||
type: string
|
||||
x-enum-varnames:
|
||||
- Steam2FAStatusIdle
|
||||
- Steam2FAStatusPending
|
||||
- Steam2FAStatusComplete
|
||||
- Steam2FAStatusError
|
||||
model.User:
|
||||
properties:
|
||||
id:
|
||||
@@ -103,11 +132,20 @@ definitions:
|
||||
username:
|
||||
type: string
|
||||
type: object
|
||||
host: localhost:3000
|
||||
service.UpdateUserRequest:
|
||||
properties:
|
||||
password:
|
||||
type: string
|
||||
roleId:
|
||||
type: string
|
||||
username:
|
||||
type: string
|
||||
type: object
|
||||
host: acc-api.jurmanovic.com
|
||||
info:
|
||||
contact:
|
||||
name: ACC Server Manager Support
|
||||
url: https://github.com/yourusername/acc-server-manager
|
||||
url: https://github.com/FJurmanovic/acc-server-manager
|
||||
description: API for managing Assetto Corsa Competizione dedicated servers
|
||||
license:
|
||||
name: MIT
|
||||
@@ -115,6 +153,62 @@ info:
|
||||
title: ACC Server Manager API
|
||||
version: "1.0"
|
||||
paths:
|
||||
/api/server:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get a list of all ACC servers with filtering options
|
||||
parameters:
|
||||
- in: query
|
||||
name: name
|
||||
type: string
|
||||
- in: query
|
||||
name: page
|
||||
type: integer
|
||||
- in: query
|
||||
name: pageSize
|
||||
type: integer
|
||||
- in: query
|
||||
name: serverID
|
||||
type: string
|
||||
- in: query
|
||||
name: serviceName
|
||||
type: string
|
||||
- in: query
|
||||
name: sortBy
|
||||
type: string
|
||||
- in: query
|
||||
name: sortDesc
|
||||
type: boolean
|
||||
- in: query
|
||||
name: status
|
||||
type: string
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: List of servers
|
||||
schema:
|
||||
items:
|
||||
$ref: '#/definitions/model.ServerAPI'
|
||||
type: array
|
||||
"400":
|
||||
description: Invalid filter parameters
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"401":
|
||||
description: Unauthorized
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal server error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: List all servers (API format)
|
||||
tags:
|
||||
- Server
|
||||
/auth/login:
|
||||
post:
|
||||
consumes:
|
||||
@@ -157,6 +251,228 @@ paths:
|
||||
summary: User login
|
||||
tags:
|
||||
- Authentication
|
||||
/auth/me:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get details of the currently authenticated user
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: Current user details
|
||||
schema:
|
||||
$ref: '#/definitions/model.User'
|
||||
"401":
|
||||
description: Unauthorized
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"404":
|
||||
description: User not found
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: Get current user details
|
||||
tags:
|
||||
- Authentication
|
||||
/auth/open-token:
|
||||
post:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Generate an open token for a user
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: JWT token
|
||||
schema:
|
||||
properties:
|
||||
token:
|
||||
type: string
|
||||
type: object
|
||||
"400":
|
||||
description: Invalid request body
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"401":
|
||||
description: Invalid credentials
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal server error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
summary: Generate an open token
|
||||
tags:
|
||||
- Authentication
|
||||
/lookup/car-models:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get a list of all available ACC car models with their identifiers
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: List of car models
|
||||
schema:
|
||||
items:
|
||||
properties:
|
||||
class:
|
||||
type: string
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
type: object
|
||||
type: array
|
||||
"401":
|
||||
description: Unauthorized
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal server error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: Get available car models
|
||||
tags:
|
||||
- Lookups
|
||||
/lookup/cup-categories:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get a list of all available racing cup categories
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: List of cup categories
|
||||
schema:
|
||||
items:
|
||||
properties:
|
||||
id:
|
||||
type: number
|
||||
name:
|
||||
type: string
|
||||
type: object
|
||||
type: array
|
||||
"401":
|
||||
description: Unauthorized
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal server error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: Get cup categories
|
||||
tags:
|
||||
- Lookups
|
||||
/lookup/driver-categories:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get a list of all driver categories (Bronze, Silver, Gold, Platinum)
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: List of driver categories
|
||||
schema:
|
||||
items:
|
||||
properties:
|
||||
description:
|
||||
type: string
|
||||
id:
|
||||
type: number
|
||||
name:
|
||||
type: string
|
||||
type: object
|
||||
type: array
|
||||
"401":
|
||||
description: Unauthorized
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal server error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: Get driver categories
|
||||
tags:
|
||||
- Lookups
|
||||
/lookup/session-types:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get a list of all available session types (Practice, Qualifying,
|
||||
Race)
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: List of session types
|
||||
schema:
|
||||
items:
|
||||
properties:
|
||||
code:
|
||||
type: string
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
type: object
|
||||
type: array
|
||||
"401":
|
||||
description: Unauthorized
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal server error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: Get session types
|
||||
tags:
|
||||
- Lookups
|
||||
/lookup/tracks:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get a list of all available ACC tracks with their identifiers
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: List of tracks
|
||||
schema:
|
||||
items:
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
type: object
|
||||
type: array
|
||||
"401":
|
||||
description: Unauthorized
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal server error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: Get available tracks
|
||||
tags:
|
||||
- Lookups
|
||||
/membership:
|
||||
get:
|
||||
consumes:
|
||||
@@ -238,117 +554,141 @@ paths:
|
||||
summary: Create a new user
|
||||
tags:
|
||||
- User Management
|
||||
/server:
|
||||
get:
|
||||
/membership/{id}:
|
||||
delete:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get a list of all ACC servers with filtering options
|
||||
description: Delete a specific user by ID
|
||||
parameters:
|
||||
- in: query
|
||||
name: name
|
||||
type: string
|
||||
- in: query
|
||||
name: page
|
||||
type: integer
|
||||
- in: query
|
||||
name: pageSize
|
||||
type: integer
|
||||
- in: query
|
||||
name: serverID
|
||||
type: string
|
||||
- in: query
|
||||
name: serviceName
|
||||
type: string
|
||||
- in: query
|
||||
name: sortBy
|
||||
type: string
|
||||
- in: query
|
||||
name: sortDesc
|
||||
type: boolean
|
||||
- in: query
|
||||
name: status
|
||||
- description: User ID (UUID format)
|
||||
in: path
|
||||
name: id
|
||||
required: true
|
||||
type: string
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: List of servers
|
||||
schema:
|
||||
items:
|
||||
$ref: '#/definitions/model.ServerAPI'
|
||||
type: array
|
||||
"204":
|
||||
description: User successfully deleted
|
||||
"400":
|
||||
description: Invalid filter parameters
|
||||
description: Invalid user ID format
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"401":
|
||||
description: Unauthorized
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal server error
|
||||
"403":
|
||||
description: Insufficient permissions
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"404":
|
||||
description: User not found
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: List all servers (API format)
|
||||
summary: Delete user
|
||||
tags:
|
||||
- Server
|
||||
/v1/lookup/car-models:
|
||||
- User Management
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get a list of all available ACC car models with their identifiers
|
||||
description: Get detailed information about a specific user
|
||||
parameters:
|
||||
- description: User ID (UUID format)
|
||||
in: path
|
||||
name: id
|
||||
required: true
|
||||
type: string
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: List of car models
|
||||
description: User details
|
||||
schema:
|
||||
$ref: '#/definitions/model.User'
|
||||
"400":
|
||||
description: Invalid user ID format
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"401":
|
||||
description: Unauthorized
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"404":
|
||||
description: User not found
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: Get user by ID
|
||||
tags:
|
||||
- User Management
|
||||
put:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Update user details by ID
|
||||
parameters:
|
||||
- description: User ID (UUID format)
|
||||
in: path
|
||||
name: id
|
||||
required: true
|
||||
type: string
|
||||
- description: Updated user details
|
||||
in: body
|
||||
name: user
|
||||
required: true
|
||||
schema:
|
||||
$ref: '#/definitions/service.UpdateUserRequest'
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: Updated user details
|
||||
schema:
|
||||
$ref: '#/definitions/model.User'
|
||||
"400":
|
||||
description: Invalid request body or ID format
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"401":
|
||||
description: Unauthorized
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"403":
|
||||
description: Insufficient permissions
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"404":
|
||||
description: User not found
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: Update user
|
||||
tags:
|
||||
- User Management
|
||||
/membership/roles:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get a list of all available user roles
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: List of roles
|
||||
schema:
|
||||
items:
|
||||
properties:
|
||||
class:
|
||||
type: string
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
type: object
|
||||
$ref: '#/definitions/model.Role'
|
||||
type: array
|
||||
"401":
|
||||
description: Unauthorized
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal server error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: Get available car models
|
||||
tags:
|
||||
- Lookups
|
||||
/v1/lookup/cup-categories:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get a list of all available racing cup categories
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: List of cup categories
|
||||
schema:
|
||||
items:
|
||||
properties:
|
||||
id:
|
||||
type: number
|
||||
name:
|
||||
type: string
|
||||
type: object
|
||||
type: array
|
||||
"401":
|
||||
description: Unauthorized
|
||||
"403":
|
||||
description: Insufficient permissions
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
@@ -357,111 +697,10 @@ paths:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: Get cup categories
|
||||
summary: Get all roles
|
||||
tags:
|
||||
- Lookups
|
||||
/v1/lookup/driver-categories:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get a list of all driver categories (Bronze, Silver, Gold, Platinum)
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: List of driver categories
|
||||
schema:
|
||||
items:
|
||||
properties:
|
||||
description:
|
||||
type: string
|
||||
id:
|
||||
type: number
|
||||
name:
|
||||
type: string
|
||||
type: object
|
||||
type: array
|
||||
"401":
|
||||
description: Unauthorized
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal server error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: Get driver categories
|
||||
tags:
|
||||
- Lookups
|
||||
/v1/lookup/session-types:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get a list of all available session types (Practice, Qualifying,
|
||||
Race)
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: List of session types
|
||||
schema:
|
||||
items:
|
||||
properties:
|
||||
code:
|
||||
type: string
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
type: object
|
||||
type: array
|
||||
"401":
|
||||
description: Unauthorized
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal server error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: Get session types
|
||||
tags:
|
||||
- Lookups
|
||||
/v1/lookup/tracks:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get a list of all available ACC tracks with their identifiers
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: List of tracks
|
||||
schema:
|
||||
items:
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
type: object
|
||||
type: array
|
||||
"401":
|
||||
description: Unauthorized
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal server error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: Get available tracks
|
||||
tags:
|
||||
- Lookups
|
||||
/v1/server:
|
||||
- User Management
|
||||
/server:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -556,7 +795,49 @@ paths:
|
||||
summary: Create a new ACC server
|
||||
tags:
|
||||
- Server
|
||||
/v1/server/{id}:
|
||||
/server/{id}:
|
||||
delete:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Delete an existing ACC server
|
||||
parameters:
|
||||
- description: Server ID (UUID format)
|
||||
in: path
|
||||
name: id
|
||||
required: true
|
||||
type: string
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: Deleted server details
|
||||
schema:
|
||||
type: object
|
||||
"400":
|
||||
description: Invalid server data or ID
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"401":
|
||||
description: Unauthorized
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"403":
|
||||
description: Insufficient permissions
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"404":
|
||||
description: Server not found
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal server error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
security:
|
||||
- BearerAuth: []
|
||||
summary: Delete an ACC server
|
||||
tags:
|
||||
- Server
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -643,7 +924,7 @@ paths:
|
||||
summary: Update an ACC server
|
||||
tags:
|
||||
- Server
|
||||
/v1/server/{id}/config:
|
||||
/server/{id}/config:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -688,7 +969,7 @@ paths:
|
||||
summary: List available configuration files
|
||||
tags:
|
||||
- Server Configuration
|
||||
/v1/server/{id}/config/{file}:
|
||||
/server/{id}/config/{file}:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -792,20 +1073,7 @@ paths:
|
||||
summary: Update server configuration file
|
||||
tags:
|
||||
- Server Configuration
|
||||
/v1/service-control:
|
||||
get:
|
||||
description: Return service control status
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
schema:
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
summary: Return service control status
|
||||
tags:
|
||||
- service-control
|
||||
/v1/service-control/{service}:
|
||||
/server/{id}/service/{service}:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -849,7 +1117,7 @@ paths:
|
||||
summary: Get service status
|
||||
tags:
|
||||
- Service Control
|
||||
/v1/service-control/restart:
|
||||
/server/{id}/service/restart:
|
||||
post:
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -899,7 +1167,7 @@ paths:
|
||||
summary: Restart a Windows service
|
||||
tags:
|
||||
- Service Control
|
||||
/v1/service-control/start:
|
||||
/server/{id}/service/start:
|
||||
post:
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -953,7 +1221,7 @@ paths:
|
||||
summary: Start a Windows service
|
||||
tags:
|
||||
- Service Control
|
||||
/v1/service-control/stop:
|
||||
/server/{id}/service/stop:
|
||||
post:
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -1007,7 +1275,7 @@ paths:
|
||||
summary: Stop a Windows service
|
||||
tags:
|
||||
- Service Control
|
||||
/v1/state-history:
|
||||
/state-history:
|
||||
get:
|
||||
description: Return StateHistorys
|
||||
responses:
|
||||
@@ -1020,7 +1288,7 @@ paths:
|
||||
summary: Return StateHistorys
|
||||
tags:
|
||||
- StateHistory
|
||||
/v1/state-history/statistics:
|
||||
/state-history/statistics:
|
||||
get:
|
||||
description: Return StateHistorys
|
||||
responses:
|
||||
@@ -1033,8 +1301,136 @@ paths:
|
||||
summary: Return StateHistorys
|
||||
tags:
|
||||
- StateHistory
|
||||
/steam2fa/{id}:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get a specific Steam 2FA authentication request by ID
|
||||
parameters:
|
||||
- description: 2FA Request ID
|
||||
in: path
|
||||
name: id
|
||||
required: true
|
||||
type: string
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
schema:
|
||||
$ref: '#/definitions/model.Steam2FARequest'
|
||||
"404":
|
||||
description: Not Found
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
summary: Get 2FA request
|
||||
tags:
|
||||
- Steam 2FA
|
||||
/steam2fa/{id}/cancel:
|
||||
post:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Cancel a Steam 2FA authentication request
|
||||
parameters:
|
||||
- description: 2FA Request ID
|
||||
in: path
|
||||
name: id
|
||||
required: true
|
||||
type: string
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
schema:
|
||||
$ref: '#/definitions/model.Steam2FARequest'
|
||||
"400":
|
||||
description: Bad Request
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"404":
|
||||
description: Not Found
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
summary: Cancel 2FA request
|
||||
tags:
|
||||
- Steam 2FA
|
||||
/steam2fa/{id}/complete:
|
||||
post:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Mark a Steam 2FA authentication request as completed
|
||||
parameters:
|
||||
- description: 2FA Request ID
|
||||
in: path
|
||||
name: id
|
||||
required: true
|
||||
type: string
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
schema:
|
||||
$ref: '#/definitions/model.Steam2FARequest'
|
||||
"400":
|
||||
description: Bad Request
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"404":
|
||||
description: Not Found
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
"500":
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
summary: Complete 2FA request
|
||||
tags:
|
||||
- Steam 2FA
|
||||
/steam2fa/pending:
|
||||
get:
|
||||
consumes:
|
||||
- application/json
|
||||
description: Get all pending Steam 2FA authentication requests
|
||||
produces:
|
||||
- application/json
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
schema:
|
||||
items:
|
||||
$ref: '#/definitions/model.Steam2FARequest'
|
||||
type: array
|
||||
"500":
|
||||
description: Internal Server Error
|
||||
schema:
|
||||
$ref: '#/definitions/error_handler.ErrorResponse'
|
||||
summary: Get pending 2FA requests
|
||||
tags:
|
||||
- Steam 2FA
|
||||
/system/health:
|
||||
get:
|
||||
description: Return service control status
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
schema:
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
summary: Return service control status
|
||||
tags:
|
||||
- system
|
||||
schemes:
|
||||
- http
|
||||
- https
|
||||
securityDefinitions:
|
||||
BearerAuth:
|
||||
|
||||
@@ -4,22 +4,26 @@ import (
|
||||
"acc-server-manager/local/model"
|
||||
"acc-server-manager/local/utl/jwt"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// GenerateTestToken creates a JWT token for testing purposes
|
||||
func GenerateTestToken() (string, error) {
|
||||
// Create test user
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "test_user",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Generate JWT token
|
||||
token, err := jwt.GenerateToken(user)
|
||||
testSecret := os.Getenv("JWT_SECRET")
|
||||
if testSecret == "" {
|
||||
testSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
|
||||
}
|
||||
jwtHandler := jwt.NewJWTHandler(testSecret)
|
||||
|
||||
token, err := jwtHandler.GenerateToken(user.ID.String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate test token: %w", err)
|
||||
}
|
||||
@@ -27,8 +31,6 @@ func GenerateTestToken() (string, error) {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// MustGenerateTestToken generates a test token and panics if it fails
|
||||
// This is useful for test setup where failing to generate a token is a fatal error
|
||||
func MustGenerateTestToken() string {
|
||||
token, err := GenerateTestToken()
|
||||
if err != nil {
|
||||
@@ -37,17 +39,20 @@ func MustGenerateTestToken() string {
|
||||
return token
|
||||
}
|
||||
|
||||
// GenerateTestTokenWithExpiry creates a JWT token with a specific expiry time
|
||||
func GenerateTestTokenWithExpiry(expiryTime time.Time) (string, error) {
|
||||
// Create test user
|
||||
testSecret := os.Getenv("JWT_SECRET")
|
||||
if testSecret == "" {
|
||||
testSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
|
||||
}
|
||||
jwtHandler := jwt.NewJWTHandler(testSecret)
|
||||
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "test_user",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Generate JWT token with custom expiry
|
||||
token, err := jwt.GenerateTokenWithExpiry(user, expiryTime)
|
||||
token, err := jwtHandler.GenerateTokenWithExpiry(user, expiryTime)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate test token with expiry: %w", err)
|
||||
}
|
||||
@@ -55,8 +60,6 @@ func GenerateTestTokenWithExpiry(expiryTime time.Time) (string, error) {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// AddAuthHeader adds a test auth token to the request headers
|
||||
// This is a convenience method for tests that need to authenticate requests
|
||||
func AddAuthHeader(headers map[string]string) (map[string]string, error) {
|
||||
token, err := GenerateTestToken()
|
||||
if err != nil {
|
||||
@@ -71,7 +74,6 @@ func AddAuthHeader(headers map[string]string) (map[string]string, error) {
|
||||
return headers, nil
|
||||
}
|
||||
|
||||
// MustAddAuthHeader adds a test auth token to the request headers and panics if it fails
|
||||
func MustAddAuthHeader(headers map[string]string) map[string]string {
|
||||
result, err := AddAuthHeader(headers)
|
||||
if err != nil {
|
||||
|
||||
@@ -8,26 +8,20 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// MockAuthMiddleware provides a test implementation of AuthMiddleware
|
||||
// that can be used as a drop-in replacement for the real AuthMiddleware
|
||||
type MockAuthMiddleware struct{}
|
||||
|
||||
// NewMockAuthMiddleware creates a new MockAuthMiddleware
|
||||
func NewMockAuthMiddleware() *MockAuthMiddleware {
|
||||
return &MockAuthMiddleware{}
|
||||
}
|
||||
|
||||
// Authenticate is a middleware that allows all requests without authentication for testing
|
||||
func (m *MockAuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
|
||||
// Set a mock user ID in context
|
||||
mockUserID := uuid.New().String()
|
||||
ctx.Locals("userID", mockUserID)
|
||||
|
||||
// Set mock user info
|
||||
mockUserInfo := &middleware.CachedUserInfo{
|
||||
UserID: mockUserID,
|
||||
Username: "test_user",
|
||||
RoleName: "Admin", // Admin role to bypass permission checks
|
||||
RoleName: "Admin",
|
||||
Permissions: map[string]bool{"*": true},
|
||||
CachedAt: time.Now(),
|
||||
}
|
||||
@@ -38,21 +32,18 @@ func (m *MockAuthMiddleware) Authenticate(ctx *fiber.Ctx) error {
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
// HasPermission is a middleware that allows all permission checks to pass for testing
|
||||
func (m *MockAuthMiddleware) HasPermission(requiredPermission string) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
return ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// AuthRateLimit is a test implementation that allows all requests
|
||||
func (m *MockAuthMiddleware) AuthRateLimit() fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
return ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireHTTPS is a test implementation that allows all HTTP requests
|
||||
func (m *MockAuthMiddleware) RequireHTTPS() fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
return ctx.Next()
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// MockConfigRepository provides a mock implementation of ConfigRepository
|
||||
type MockConfigRepository struct {
|
||||
configs map[string]*model.Config
|
||||
shouldFailGet bool
|
||||
@@ -21,7 +20,6 @@ func NewMockConfigRepository() *MockConfigRepository {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig mocks the UpdateConfig method
|
||||
func (m *MockConfigRepository) UpdateConfig(ctx context.Context, config *model.Config) *model.Config {
|
||||
if m.shouldFailUpdate {
|
||||
return nil
|
||||
@@ -36,18 +34,15 @@ func (m *MockConfigRepository) UpdateConfig(ctx context.Context, config *model.C
|
||||
return config
|
||||
}
|
||||
|
||||
// SetShouldFailUpdate configures the mock to fail on UpdateConfig calls
|
||||
func (m *MockConfigRepository) SetShouldFailUpdate(shouldFail bool) {
|
||||
m.shouldFailUpdate = shouldFail
|
||||
}
|
||||
|
||||
// GetConfig retrieves a config by server ID and config file
|
||||
func (m *MockConfigRepository) GetConfig(serverID uuid.UUID, configFile string) *model.Config {
|
||||
key := serverID.String() + "_" + configFile
|
||||
return m.configs[key]
|
||||
}
|
||||
|
||||
// MockServerRepository provides a mock implementation of ServerRepository
|
||||
type MockServerRepository struct {
|
||||
servers map[uuid.UUID]*model.Server
|
||||
shouldFailGet bool
|
||||
@@ -59,7 +54,6 @@ func NewMockServerRepository() *MockServerRepository {
|
||||
}
|
||||
}
|
||||
|
||||
// GetByID mocks the GetByID method
|
||||
func (m *MockServerRepository) GetByID(ctx context.Context, id interface{}) (*model.Server, error) {
|
||||
if m.shouldFailGet {
|
||||
return nil, errors.New("server not found")
|
||||
@@ -88,17 +82,14 @@ func (m *MockServerRepository) GetByID(ctx context.Context, id interface{}) (*mo
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// AddServer adds a server to the mock repository
|
||||
func (m *MockServerRepository) AddServer(server *model.Server) {
|
||||
m.servers[server.ID] = server
|
||||
}
|
||||
|
||||
// SetShouldFailGet configures the mock to fail on GetByID calls
|
||||
func (m *MockServerRepository) SetShouldFailGet(shouldFail bool) {
|
||||
m.shouldFailGet = shouldFail
|
||||
}
|
||||
|
||||
// MockServerService provides a mock implementation of ServerService
|
||||
type MockServerService struct {
|
||||
startRuntimeCalled bool
|
||||
startRuntimeServer *model.Server
|
||||
@@ -108,23 +99,19 @@ func NewMockServerService() *MockServerService {
|
||||
return &MockServerService{}
|
||||
}
|
||||
|
||||
// StartAccServerRuntime mocks the StartAccServerRuntime method
|
||||
func (m *MockServerService) StartAccServerRuntime(server *model.Server) {
|
||||
m.startRuntimeCalled = true
|
||||
m.startRuntimeServer = server
|
||||
}
|
||||
|
||||
// WasStartRuntimeCalled returns whether StartAccServerRuntime was called
|
||||
func (m *MockServerService) WasStartRuntimeCalled() bool {
|
||||
return m.startRuntimeCalled
|
||||
}
|
||||
|
||||
// GetStartRuntimeServer returns the server passed to StartAccServerRuntime
|
||||
func (m *MockServerService) GetStartRuntimeServer() *model.Server {
|
||||
return m.startRuntimeServer
|
||||
}
|
||||
|
||||
// Reset resets the mock state
|
||||
func (m *MockServerService) Reset() {
|
||||
m.startRuntimeCalled = false
|
||||
m.startRuntimeServer = nil
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// MockStateHistoryRepository provides a mock implementation of StateHistoryRepository
|
||||
type MockStateHistoryRepository struct {
|
||||
stateHistories []model.StateHistory
|
||||
shouldFailGet bool
|
||||
@@ -21,7 +20,6 @@ func NewMockStateHistoryRepository() *MockStateHistoryRepository {
|
||||
}
|
||||
}
|
||||
|
||||
// GetAll mocks the GetAll method
|
||||
func (m *MockStateHistoryRepository) GetAll(ctx context.Context, filter *model.StateHistoryFilter) (*[]model.StateHistory, error) {
|
||||
if m.shouldFailGet {
|
||||
return nil, errors.New("failed to get state history")
|
||||
@@ -37,13 +35,11 @@ func (m *MockStateHistoryRepository) GetAll(ctx context.Context, filter *model.S
|
||||
return &filtered, nil
|
||||
}
|
||||
|
||||
// Insert mocks the Insert method
|
||||
func (m *MockStateHistoryRepository) Insert(ctx context.Context, stateHistory *model.StateHistory) error {
|
||||
if m.shouldFailInsert {
|
||||
return errors.New("failed to insert state history")
|
||||
}
|
||||
|
||||
// Simulate BeforeCreate hook
|
||||
if stateHistory.ID == uuid.Nil {
|
||||
stateHistory.ID = uuid.New()
|
||||
}
|
||||
@@ -55,7 +51,6 @@ func (m *MockStateHistoryRepository) Insert(ctx context.Context, stateHistory *m
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLastSessionID mocks the GetLastSessionID method
|
||||
func (m *MockStateHistoryRepository) GetLastSessionID(ctx context.Context, serverID uuid.UUID) (uuid.UUID, error) {
|
||||
for i := len(m.stateHistories) - 1; i >= 0; i-- {
|
||||
if m.stateHistories[i].ServerID == serverID {
|
||||
@@ -65,7 +60,6 @@ func (m *MockStateHistoryRepository) GetLastSessionID(ctx context.Context, serve
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
|
||||
// Helper methods for filtering
|
||||
func (m *MockStateHistoryRepository) matchesFilter(sh model.StateHistory, filter *model.StateHistoryFilter) bool {
|
||||
if filter == nil {
|
||||
return true
|
||||
@@ -93,7 +87,6 @@ func (m *MockStateHistoryRepository) matchesFilter(sh model.StateHistory, filter
|
||||
return true
|
||||
}
|
||||
|
||||
// Helper methods for testing configuration
|
||||
func (m *MockStateHistoryRepository) SetShouldFailGet(shouldFail bool) {
|
||||
m.shouldFailGet = shouldFail
|
||||
}
|
||||
@@ -102,7 +95,6 @@ func (m *MockStateHistoryRepository) SetShouldFailInsert(shouldFail bool) {
|
||||
m.shouldFailInsert = shouldFail
|
||||
}
|
||||
|
||||
// AddStateHistory adds a state history entry to the mock repository
|
||||
func (m *MockStateHistoryRepository) AddStateHistory(stateHistory model.StateHistory) {
|
||||
if stateHistory.ID == uuid.Nil {
|
||||
stateHistory.ID = uuid.New()
|
||||
@@ -113,22 +105,18 @@ func (m *MockStateHistoryRepository) AddStateHistory(stateHistory model.StateHis
|
||||
m.stateHistories = append(m.stateHistories, stateHistory)
|
||||
}
|
||||
|
||||
// GetCount returns the number of state history entries
|
||||
func (m *MockStateHistoryRepository) GetCount() int {
|
||||
return len(m.stateHistories)
|
||||
}
|
||||
|
||||
// Clear removes all state history entries
|
||||
func (m *MockStateHistoryRepository) Clear() {
|
||||
m.stateHistories = make([]model.StateHistory, 0)
|
||||
}
|
||||
|
||||
// GetSummaryStats calculates peak players, total sessions, and average players for mock data
|
||||
func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter *model.StateHistoryFilter) (model.StateHistoryStats, error) {
|
||||
var stats model.StateHistoryStats
|
||||
var filteredEntries []model.StateHistory
|
||||
|
||||
// Filter entries
|
||||
for _, entry := range m.stateHistories {
|
||||
if m.matchesFilter(entry, filter) {
|
||||
filteredEntries = append(filteredEntries, entry)
|
||||
@@ -139,7 +127,6 @@ func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// Calculate statistics
|
||||
sessionMap := make(map[string]bool)
|
||||
totalPlayers := 0
|
||||
|
||||
@@ -159,11 +146,9 @@ func (m *MockStateHistoryRepository) GetSummaryStats(ctx context.Context, filter
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetTotalPlaytime calculates total playtime in minutes for mock data
|
||||
func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filter *model.StateHistoryFilter) (int, error) {
|
||||
var filteredEntries []model.StateHistory
|
||||
|
||||
// Filter entries
|
||||
for _, entry := range m.stateHistories {
|
||||
if m.matchesFilter(entry, filter) {
|
||||
filteredEntries = append(filteredEntries, entry)
|
||||
@@ -174,7 +159,6 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Group by session and calculate durations
|
||||
sessionMap := make(map[string][]model.StateHistory)
|
||||
for _, entry := range filteredEntries {
|
||||
sessionID := entry.SessionID.String()
|
||||
@@ -184,7 +168,6 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte
|
||||
totalMinutes := 0
|
||||
for _, sessionEntries := range sessionMap {
|
||||
if len(sessionEntries) > 1 {
|
||||
// Sort by date (simple approach for mock)
|
||||
minTime := sessionEntries[0].DateCreated
|
||||
maxTime := sessionEntries[0].DateCreated
|
||||
hasPlayers := false
|
||||
@@ -211,26 +194,22 @@ func (m *MockStateHistoryRepository) GetTotalPlaytime(ctx context.Context, filte
|
||||
return totalMinutes, nil
|
||||
}
|
||||
|
||||
// GetPlayerCountOverTime returns downsampled player count data for mock
|
||||
func (m *MockStateHistoryRepository) GetPlayerCountOverTime(ctx context.Context, filter *model.StateHistoryFilter) ([]model.PlayerCountPoint, error) {
|
||||
var points []model.PlayerCountPoint
|
||||
var filteredEntries []model.StateHistory
|
||||
|
||||
// Filter entries
|
||||
for _, entry := range m.stateHistories {
|
||||
if m.matchesFilter(entry, filter) {
|
||||
filteredEntries = append(filteredEntries, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Group by hour (simple mock implementation)
|
||||
hourMap := make(map[string][]int)
|
||||
for _, entry := range filteredEntries {
|
||||
hourKey := entry.DateCreated.Format("2006-01-02 15")
|
||||
hourMap[hourKey] = append(hourMap[hourKey], entry.PlayerCount)
|
||||
}
|
||||
|
||||
// Calculate averages per hour
|
||||
for hourKey, counts := range hourMap {
|
||||
total := 0
|
||||
for _, count := range counts {
|
||||
@@ -247,20 +226,17 @@ func (m *MockStateHistoryRepository) GetPlayerCountOverTime(ctx context.Context,
|
||||
return points, nil
|
||||
}
|
||||
|
||||
// GetSessionTypes counts sessions by type for mock
|
||||
func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter *model.StateHistoryFilter) ([]model.SessionCount, error) {
|
||||
var sessionTypes []model.SessionCount
|
||||
var filteredEntries []model.StateHistory
|
||||
|
||||
// Filter entries
|
||||
for _, entry := range m.stateHistories {
|
||||
if m.matchesFilter(entry, filter) {
|
||||
filteredEntries = append(filteredEntries, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Group by session type
|
||||
sessionMap := make(map[string]map[string]bool) // session -> sessionID -> bool
|
||||
sessionMap := make(map[model.TrackSession]map[string]bool)
|
||||
for _, entry := range filteredEntries {
|
||||
if sessionMap[entry.Session] == nil {
|
||||
sessionMap[entry.Session] = make(map[string]bool)
|
||||
@@ -268,7 +244,6 @@ func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter
|
||||
sessionMap[entry.Session][entry.SessionID.String()] = true
|
||||
}
|
||||
|
||||
// Count unique sessions per type
|
||||
for sessionType, sessions := range sessionMap {
|
||||
sessionTypes = append(sessionTypes, model.SessionCount{
|
||||
Name: sessionType,
|
||||
@@ -279,20 +254,17 @@ func (m *MockStateHistoryRepository) GetSessionTypes(ctx context.Context, filter
|
||||
return sessionTypes, nil
|
||||
}
|
||||
|
||||
// GetDailyActivity counts sessions per day for mock
|
||||
func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filter *model.StateHistoryFilter) ([]model.DailyActivity, error) {
|
||||
var dailyActivity []model.DailyActivity
|
||||
var filteredEntries []model.StateHistory
|
||||
|
||||
// Filter entries
|
||||
for _, entry := range m.stateHistories {
|
||||
if m.matchesFilter(entry, filter) {
|
||||
filteredEntries = append(filteredEntries, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Group by day
|
||||
dayMap := make(map[string]map[string]bool) // date -> sessionID -> bool
|
||||
dayMap := make(map[string]map[string]bool)
|
||||
for _, entry := range filteredEntries {
|
||||
dateKey := entry.DateCreated.Format("2006-01-02")
|
||||
if dayMap[dateKey] == nil {
|
||||
@@ -301,7 +273,6 @@ func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filte
|
||||
dayMap[dateKey][entry.SessionID.String()] = true
|
||||
}
|
||||
|
||||
// Count unique sessions per day
|
||||
for date, sessions := range dayMap {
|
||||
dailyActivity = append(dailyActivity, model.DailyActivity{
|
||||
Date: date,
|
||||
@@ -312,26 +283,22 @@ func (m *MockStateHistoryRepository) GetDailyActivity(ctx context.Context, filte
|
||||
return dailyActivity, nil
|
||||
}
|
||||
|
||||
// GetRecentSessions retrieves recent sessions for mock
|
||||
func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filter *model.StateHistoryFilter) ([]model.RecentSession, error) {
|
||||
var recentSessions []model.RecentSession
|
||||
var filteredEntries []model.StateHistory
|
||||
|
||||
// Filter entries
|
||||
for _, entry := range m.stateHistories {
|
||||
if m.matchesFilter(entry, filter) {
|
||||
filteredEntries = append(filteredEntries, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Group by session
|
||||
sessionMap := make(map[string][]model.StateHistory)
|
||||
for _, entry := range filteredEntries {
|
||||
sessionID := entry.SessionID.String()
|
||||
sessionMap[sessionID] = append(sessionMap[sessionID], entry)
|
||||
}
|
||||
|
||||
// Create recent sessions (limit to 10)
|
||||
count := 0
|
||||
for _, entries := range sessionMap {
|
||||
if count >= 10 {
|
||||
@@ -339,7 +306,6 @@ func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filt
|
||||
}
|
||||
|
||||
if len(entries) > 0 {
|
||||
// Find min/max dates and max players
|
||||
minDate := entries[0].DateCreated
|
||||
maxDate := entries[0].DateCreated
|
||||
maxPlayers := 0
|
||||
@@ -356,11 +322,10 @@ func (m *MockStateHistoryRepository) GetRecentSessions(ctx context.Context, filt
|
||||
}
|
||||
}
|
||||
|
||||
// Only include sessions with players
|
||||
if maxPlayers > 0 {
|
||||
duration := int(maxDate.Sub(minDate).Minutes())
|
||||
recentSessions = append(recentSessions, model.RecentSession{
|
||||
ID: uint(count + 1),
|
||||
ID: entries[0].SessionID,
|
||||
Date: minDate.Format("2006-01-02 15:04:05"),
|
||||
Type: entries[0].Session,
|
||||
Track: entries[0].Track,
|
||||
|
||||
@@ -3,7 +3,6 @@ package tests
|
||||
import (
|
||||
"acc-server-manager/local/model"
|
||||
"acc-server-manager/local/utl/configs"
|
||||
"acc-server-manager/local/utl/jwt"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
@@ -25,14 +24,12 @@ import (
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// TestHelper provides utilities for testing
|
||||
type TestHelper struct {
|
||||
DB *gorm.DB
|
||||
TempDir string
|
||||
TestData *TestData
|
||||
}
|
||||
|
||||
// TestData contains common test data structures
|
||||
type TestData struct {
|
||||
ServerID uuid.UUID
|
||||
Server *model.Server
|
||||
@@ -40,38 +37,29 @@ type TestData struct {
|
||||
SampleConfig *model.Configuration
|
||||
}
|
||||
|
||||
// SetTestEnv sets the required environment variables for tests
|
||||
func SetTestEnv() {
|
||||
// Set required environment variables for testing
|
||||
os.Setenv("APP_SECRET", "test-secret-key-for-testing-123456")
|
||||
os.Setenv("APP_SECRET_CODE", "test-code-for-testing-123456789012")
|
||||
os.Setenv("ENCRYPTION_KEY", "12345678901234567890123456789012")
|
||||
os.Setenv("JWT_SECRET", "test-jwt-secret-key-for-testing-123456789012345678901234567890")
|
||||
os.Setenv("ACCESS_KEY", "test-access-key-for-testing")
|
||||
// Set test-specific environment variables
|
||||
os.Setenv("TESTING_ENV", "true") // Used to bypass
|
||||
os.Setenv("TESTING_ENV", "true")
|
||||
|
||||
configs.Init()
|
||||
jwt.Init()
|
||||
}
|
||||
|
||||
// NewTestHelper creates a new test helper with in-memory database
|
||||
func NewTestHelper(t *testing.T) *TestHelper {
|
||||
// Set required environment variables
|
||||
SetTestEnv()
|
||||
|
||||
// Create temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create in-memory SQLite database for testing
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent), // Suppress SQL logs in tests
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to test database: %v", err)
|
||||
}
|
||||
|
||||
// Auto-migrate the schema
|
||||
err = db.AutoMigrate(
|
||||
&model.Server{},
|
||||
&model.Config{},
|
||||
@@ -81,7 +69,6 @@ func NewTestHelper(t *testing.T) *TestHelper {
|
||||
&model.StateHistory{},
|
||||
)
|
||||
|
||||
// Explicitly ensure tables exist with correct structure
|
||||
if !db.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err = db.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -92,7 +79,6 @@ func NewTestHelper(t *testing.T) *TestHelper {
|
||||
t.Fatalf("Failed to migrate test database: %v", err)
|
||||
}
|
||||
|
||||
// Create test data
|
||||
testData := createTestData(t, tempDir)
|
||||
|
||||
return &TestHelper{
|
||||
@@ -102,11 +88,9 @@ func NewTestHelper(t *testing.T) *TestHelper {
|
||||
}
|
||||
}
|
||||
|
||||
// createTestData creates common test data structures
|
||||
func createTestData(t *testing.T, tempDir string) *TestData {
|
||||
serverID := uuid.New()
|
||||
|
||||
// Create sample server
|
||||
server := &model.Server{
|
||||
ID: serverID,
|
||||
Name: "Test Server",
|
||||
@@ -117,13 +101,11 @@ func createTestData(t *testing.T, tempDir string) *TestData {
|
||||
FromSteamCMD: false,
|
||||
}
|
||||
|
||||
// Create server directory
|
||||
serverConfigDir := filepath.Join(tempDir, "server", "cfg")
|
||||
if err := os.MkdirAll(serverConfigDir, 0755); err != nil {
|
||||
t.Fatalf("Failed to create server config directory: %v", err)
|
||||
}
|
||||
|
||||
// Sample configuration files content
|
||||
configFiles := map[string]string{
|
||||
"configuration.json": `{
|
||||
"udpPort": "9231",
|
||||
@@ -214,7 +196,6 @@ func createTestData(t *testing.T, tempDir string) *TestData {
|
||||
}`,
|
||||
}
|
||||
|
||||
// Sample configuration struct
|
||||
sampleConfig := &model.Configuration{
|
||||
UdpPort: model.IntString(9231),
|
||||
TcpPort: model.IntString(9232),
|
||||
@@ -232,14 +213,12 @@ func createTestData(t *testing.T, tempDir string) *TestData {
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTestConfigFiles creates actual config files in the test directory
|
||||
func (th *TestHelper) CreateTestConfigFiles() error {
|
||||
serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg")
|
||||
|
||||
for filename, content := range th.TestData.ConfigFiles {
|
||||
filePath := filepath.Join(serverConfigDir, filename)
|
||||
|
||||
// Encode content to UTF-16 LE BOM format as expected by the application
|
||||
utf16Content, err := EncodeUTF16LEBOM([]byte(content))
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -253,7 +232,6 @@ func (th *TestHelper) CreateTestConfigFiles() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateMalformedConfigFile creates a config file with invalid JSON
|
||||
func (th *TestHelper) CreateMalformedConfigFile(filename string) error {
|
||||
serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg")
|
||||
filePath := filepath.Join(serverConfigDir, filename)
|
||||
@@ -261,67 +239,51 @@ func (th *TestHelper) CreateMalformedConfigFile(filename string) error {
|
||||
malformedJSON := `{
|
||||
"udpPort": "9231",
|
||||
"tcpPort": "9232"
|
||||
"maxConnections": "30" // Missing comma - invalid JSON
|
||||
"maxConnections": "30"
|
||||
}`
|
||||
|
||||
return os.WriteFile(filePath, []byte(malformedJSON), 0644)
|
||||
}
|
||||
|
||||
// RemoveConfigFile removes a config file to simulate missing file scenarios
|
||||
func (th *TestHelper) RemoveConfigFile(filename string) error {
|
||||
serverConfigDir := filepath.Join(th.TestData.Server.Path, "cfg")
|
||||
filePath := filepath.Join(serverConfigDir, filename)
|
||||
return os.Remove(filePath)
|
||||
}
|
||||
|
||||
// InsertTestServer inserts the test server into the database
|
||||
func (th *TestHelper) InsertTestServer() error {
|
||||
return th.DB.Create(th.TestData.Server).Error
|
||||
}
|
||||
|
||||
// CreateContext creates a test context
|
||||
func (th *TestHelper) CreateContext() context.Context {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// CreateFiberCtx creates a fiber.Ctx for testing
|
||||
func (th *TestHelper) CreateFiberCtx() *fiber.Ctx {
|
||||
// Create app and request for fiber context
|
||||
app := fiber.New()
|
||||
// Create a dummy request that doesn't depend on external http objects
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
// Create the fiber context from real request/response
|
||||
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
// Store the original request for release later
|
||||
ctx.Locals("original-request", req)
|
||||
// Return the context which can be safely used in tests
|
||||
return ctx
|
||||
}
|
||||
|
||||
// ReleaseFiberCtx properly releases a fiber context created with CreateFiberCtx
|
||||
func (th *TestHelper) ReleaseFiberCtx(app *fiber.App, ctx *fiber.Ctx) {
|
||||
if app != nil && ctx != nil {
|
||||
app.ReleaseCtx(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup performs cleanup operations after tests
|
||||
func (th *TestHelper) Cleanup() {
|
||||
// Close database connection
|
||||
if sqlDB, err := th.DB.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
|
||||
// Temporary directory is automatically cleaned up by t.TempDir()
|
||||
}
|
||||
|
||||
// LoadTestEnvFile loads environment variables from a .env file for testing
|
||||
func LoadTestEnvFile() error {
|
||||
// Try to load from .env file
|
||||
return godotenv.Load()
|
||||
}
|
||||
|
||||
// AssertNoError is a helper function to check for errors in tests
|
||||
func AssertNoError(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
@@ -329,7 +291,6 @@ func AssertNoError(t *testing.T, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// AssertError is a helper function to check for expected errors
|
||||
func AssertError(t *testing.T, err error, expectedMsg string) {
|
||||
t.Helper()
|
||||
if err == nil {
|
||||
@@ -340,7 +301,6 @@ func AssertError(t *testing.T, err error, expectedMsg string) {
|
||||
}
|
||||
}
|
||||
|
||||
// AssertEqual checks if two values are equal
|
||||
func AssertEqual(t *testing.T, expected, actual interface{}) {
|
||||
t.Helper()
|
||||
if expected != actual {
|
||||
@@ -348,7 +308,6 @@ func AssertEqual(t *testing.T, expected, actual interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// AssertNotNil checks if a value is not nil
|
||||
func AssertNotNil(t *testing.T, value interface{}) {
|
||||
t.Helper()
|
||||
if value == nil {
|
||||
@@ -356,12 +315,9 @@ func AssertNotNil(t *testing.T, value interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// AssertNil checks if a value is nil
|
||||
func AssertNil(t *testing.T, value interface{}) {
|
||||
t.Helper()
|
||||
if value != nil {
|
||||
// Special handling for interface values that contain nil but aren't nil themselves
|
||||
// For example, (*jwt.Claims)(nil) is not equal to nil, but it contains nil
|
||||
switch v := value.(type) {
|
||||
case *interface{}:
|
||||
if v == nil || *v == nil {
|
||||
@@ -376,13 +332,11 @@ func AssertNil(t *testing.T, value interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeUTF16LEBOM encodes UTF-8 bytes to UTF-16 LE BOM format
|
||||
func EncodeUTF16LEBOM(input []byte) ([]byte, error) {
|
||||
encoder := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM)
|
||||
return transformBytes(encoder.NewEncoder(), input)
|
||||
}
|
||||
|
||||
// transformBytes applies a transform to input bytes
|
||||
func transformBytes(t transform.Transformer, input []byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
w := transform.NewWriter(&buf, t)
|
||||
@@ -398,7 +352,6 @@ func transformBytes(t transform.Transformer, input []byte) ([]byte, error) {
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// ErrorForTesting creates an error for testing purposes
|
||||
func ErrorForTesting(message string) error {
|
||||
return errors.New(message)
|
||||
}
|
||||
|
||||
27
tests/testdata/state_history_data.go
vendored
27
tests/testdata/state_history_data.go
vendored
@@ -7,13 +7,11 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// StateHistoryTestData provides simple test data generators
|
||||
type StateHistoryTestData struct {
|
||||
ServerID uuid.UUID
|
||||
BaseTime time.Time
|
||||
}
|
||||
|
||||
// NewStateHistoryTestData creates a new test data generator
|
||||
func NewStateHistoryTestData(serverID uuid.UUID) *StateHistoryTestData {
|
||||
return &StateHistoryTestData{
|
||||
ServerID: serverID,
|
||||
@@ -21,8 +19,7 @@ func NewStateHistoryTestData(serverID uuid.UUID) *StateHistoryTestData {
|
||||
}
|
||||
}
|
||||
|
||||
// CreateStateHistory creates a basic state history entry
|
||||
func (td *StateHistoryTestData) CreateStateHistory(session string, track string, playerCount int, sessionID uuid.UUID) model.StateHistory {
|
||||
func (td *StateHistoryTestData) CreateStateHistory(session model.TrackSession, track string, playerCount int, sessionID uuid.UUID) model.StateHistory {
|
||||
return model.StateHistory{
|
||||
ID: uuid.New(),
|
||||
ServerID: td.ServerID,
|
||||
@@ -36,8 +33,7 @@ func (td *StateHistoryTestData) CreateStateHistory(session string, track string,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateMultipleEntries creates multiple state history entries for the same session
|
||||
func (td *StateHistoryTestData) CreateMultipleEntries(session string, track string, playerCounts []int) []model.StateHistory {
|
||||
func (td *StateHistoryTestData) CreateMultipleEntries(session model.TrackSession, track string, playerCounts []int) []model.StateHistory {
|
||||
sessionID := uuid.New()
|
||||
var entries []model.StateHistory
|
||||
|
||||
@@ -59,7 +55,6 @@ func (td *StateHistoryTestData) CreateMultipleEntries(session string, track stri
|
||||
return entries
|
||||
}
|
||||
|
||||
// CreateBasicFilter creates a basic filter for testing
|
||||
func CreateBasicFilter(serverID string) *model.StateHistoryFilter {
|
||||
return &model.StateHistoryFilter{
|
||||
ServerBasedFilter: model.ServerBasedFilter{
|
||||
@@ -68,8 +63,7 @@ func CreateBasicFilter(serverID string) *model.StateHistoryFilter {
|
||||
}
|
||||
}
|
||||
|
||||
// CreateFilterWithSession creates a filter with session type
|
||||
func CreateFilterWithSession(serverID string, session string) *model.StateHistoryFilter {
|
||||
func CreateFilterWithSession(serverID string, session model.TrackSession) *model.StateHistoryFilter {
|
||||
return &model.StateHistoryFilter{
|
||||
ServerBasedFilter: model.ServerBasedFilter{
|
||||
ServerID: serverID,
|
||||
@@ -78,7 +72,6 @@ func CreateFilterWithSession(serverID string, session string) *model.StateHistor
|
||||
}
|
||||
}
|
||||
|
||||
// LogLines contains sample ACC server log lines for testing
|
||||
var SampleLogLines = []string{
|
||||
"[2024-01-15 14:30:25.123] Session changed: NONE -> PRACTICE",
|
||||
"[2024-01-15 14:30:30.456] 1 client(s) online",
|
||||
@@ -95,16 +88,14 @@ var SampleLogLines = []string{
|
||||
"[2024-01-15 15:00:05.123] Session changed: RACE -> NONE",
|
||||
}
|
||||
|
||||
// ExpectedSessionChanges represents the expected session changes from parsing the sample log lines
|
||||
var ExpectedSessionChanges = []struct {
|
||||
From string
|
||||
To string
|
||||
From model.TrackSession
|
||||
To model.TrackSession
|
||||
}{
|
||||
{"NONE", "PRACTICE"},
|
||||
{"PRACTICE", "QUALIFY"},
|
||||
{"QUALIFY", "RACE"},
|
||||
{"RACE", "NONE"},
|
||||
{model.SessionUnknown, model.SessionPractice},
|
||||
{model.SessionPractice, model.SessionQualify},
|
||||
{model.SessionQualify, model.SessionRace},
|
||||
{model.SessionRace, model.SessionUnknown},
|
||||
}
|
||||
|
||||
// ExpectedPlayerCounts represents the expected player counts from parsing the sample log lines
|
||||
var ExpectedPlayerCounts = []int{1, 3, 5, 8, 12, 15, 14, 0}
|
||||
|
||||
@@ -14,12 +14,10 @@ import (
|
||||
)
|
||||
|
||||
func TestController_JSONParsing_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test basic JSON parsing functionality
|
||||
app := fiber.New()
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
@@ -30,7 +28,6 @@ func TestController_JSONParsing_Success(t *testing.T) {
|
||||
return c.JSON(data)
|
||||
})
|
||||
|
||||
// Prepare test data
|
||||
testData := map[string]interface{}{
|
||||
"name": "test",
|
||||
"value": 123,
|
||||
@@ -38,34 +35,28 @@ func TestController_JSONParsing_Success(t *testing.T) {
|
||||
bodyBytes, err := json.Marshal(testData)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
// Parse response
|
||||
var response map[string]interface{}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
err = json.Unmarshal(body, &response)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, "test", response["name"])
|
||||
tests.AssertEqual(t, float64(123), response["value"]) // JSON numbers are float64
|
||||
tests.AssertEqual(t, float64(123), response["value"])
|
||||
}
|
||||
|
||||
func TestController_JSONParsing_InvalidJSON(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test handling of invalid JSON
|
||||
app := fiber.New()
|
||||
|
||||
app.Post("/test", func(c *fiber.Ctx) error {
|
||||
@@ -76,39 +67,31 @@ func TestController_JSONParsing_InvalidJSON(t *testing.T) {
|
||||
return c.JSON(data)
|
||||
})
|
||||
|
||||
// Create request with invalid JSON
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewReader([]byte("invalid json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
// Parse error response
|
||||
var response map[string]interface{}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
err = json.Unmarshal(body, &response)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify error response
|
||||
tests.AssertEqual(t, "Invalid JSON", response["error"])
|
||||
}
|
||||
|
||||
func TestController_UUIDValidation_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test UUID parameter validation
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/test/:id", func(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
|
||||
// Validate UUID
|
||||
if _, err := uuid.Parse(id); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid UUID"})
|
||||
}
|
||||
@@ -116,40 +99,33 @@ func TestController_UUIDValidation_Success(t *testing.T) {
|
||||
return c.JSON(fiber.Map{"id": id, "valid": true})
|
||||
})
|
||||
|
||||
// Create request with valid UUID
|
||||
validUUID := uuid.New().String()
|
||||
req := httptest.NewRequest("GET", "/test/"+validUUID, nil)
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
// Parse response
|
||||
var response map[string]interface{}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
err = json.Unmarshal(body, &response)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, validUUID, response["id"])
|
||||
tests.AssertEqual(t, true, response["valid"])
|
||||
}
|
||||
|
||||
func TestController_UUIDValidation_InvalidUUID(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test handling of invalid UUID
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/test/:id", func(c *fiber.Ctx) error {
|
||||
id := c.Params("id")
|
||||
|
||||
// Validate UUID
|
||||
if _, err := uuid.Parse(id); err != nil {
|
||||
return c.Status(400).JSON(fiber.Map{"error": "Invalid UUID"})
|
||||
}
|
||||
@@ -157,32 +133,26 @@ func TestController_UUIDValidation_InvalidUUID(t *testing.T) {
|
||||
return c.JSON(fiber.Map{"id": id, "valid": true})
|
||||
})
|
||||
|
||||
// Create request with invalid UUID
|
||||
req := httptest.NewRequest("GET", "/test/invalid-uuid", nil)
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
// Parse error response
|
||||
var response map[string]interface{}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
err = json.Unmarshal(body, &response)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify error response
|
||||
tests.AssertEqual(t, "Invalid UUID", response["error"])
|
||||
}
|
||||
|
||||
func TestController_QueryParameters_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test query parameter handling
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
@@ -197,34 +167,28 @@ func TestController_QueryParameters_Success(t *testing.T) {
|
||||
})
|
||||
})
|
||||
|
||||
// Create request with query parameters
|
||||
req := httptest.NewRequest("GET", "/test?restart=true&override=false&format=xml", nil)
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
// Parse response
|
||||
var response map[string]interface{}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
err = json.Unmarshal(body, &response)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, true, response["restart"])
|
||||
tests.AssertEqual(t, false, response["override"])
|
||||
tests.AssertEqual(t, "xml", response["format"])
|
||||
}
|
||||
|
||||
func TestController_HTTPMethods_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test different HTTP methods
|
||||
app := fiber.New()
|
||||
|
||||
var getCalled, postCalled, putCalled, deleteCalled bool
|
||||
@@ -249,28 +213,24 @@ func TestController_HTTPMethods_Success(t *testing.T) {
|
||||
return c.JSON(fiber.Map{"method": "DELETE"})
|
||||
})
|
||||
|
||||
// Test GET
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
tests.AssertEqual(t, true, getCalled)
|
||||
|
||||
// Test POST
|
||||
req = httptest.NewRequest("POST", "/test", nil)
|
||||
resp, err = app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
tests.AssertEqual(t, true, postCalled)
|
||||
|
||||
// Test PUT
|
||||
req = httptest.NewRequest("PUT", "/test", nil)
|
||||
resp, err = app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
tests.AssertEqual(t, true, putCalled)
|
||||
|
||||
// Test DELETE
|
||||
req = httptest.NewRequest("DELETE", "/test", nil)
|
||||
resp, err = app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
@@ -279,12 +239,10 @@ func TestController_HTTPMethods_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestController_ErrorHandling_StatusCodes(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test different error status codes
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/400", func(c *fiber.Ctx) error {
|
||||
@@ -307,7 +265,6 @@ func TestController_ErrorHandling_StatusCodes(t *testing.T) {
|
||||
return c.Status(500).JSON(fiber.Map{"error": "Internal Server Error"})
|
||||
})
|
||||
|
||||
// Test different status codes
|
||||
testCases := []struct {
|
||||
path string
|
||||
code int
|
||||
@@ -328,12 +285,10 @@ func TestController_ErrorHandling_StatusCodes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestController_ConfigurationModel_JSONSerialization(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test Configuration model JSON serialization
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/config", func(c *fiber.Ctx) error {
|
||||
@@ -348,22 +303,18 @@ func TestController_ConfigurationModel_JSONSerialization(t *testing.T) {
|
||||
return c.JSON(config)
|
||||
})
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest("GET", "/config", nil)
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
// Parse response
|
||||
var response model.Configuration
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
err = json.Unmarshal(body, &response)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, model.IntString(9231), response.UdpPort)
|
||||
tests.AssertEqual(t, model.IntString(9232), response.TcpPort)
|
||||
tests.AssertEqual(t, model.IntString(30), response.MaxConnections)
|
||||
@@ -373,73 +324,61 @@ func TestController_ConfigurationModel_JSONSerialization(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestController_UserModel_JSONSerialization(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test User model JSON serialization (password should be hidden)
|
||||
app := fiber.New()
|
||||
|
||||
app.Get("/user", func(c *fiber.Ctx) error {
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "testuser",
|
||||
Password: "secret-password", // Should not appear in JSON
|
||||
Password: "secret-password",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
return c.JSON(user)
|
||||
})
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest("GET", "/user", nil)
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
// Parse response as raw JSON to check password is excluded
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify password field is not in JSON
|
||||
if bytes.Contains(body, []byte("password")) || bytes.Contains(body, []byte("secret-password")) {
|
||||
t.Fatal("Password should not be included in JSON response")
|
||||
}
|
||||
|
||||
// Verify other fields are present
|
||||
if !bytes.Contains(body, []byte("username")) || !bytes.Contains(body, []byte("testuser")) {
|
||||
t.Fatal("Username should be included in JSON response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestController_MiddlewareChaining_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Test middleware chaining
|
||||
app := fiber.New()
|
||||
|
||||
var middleware1Called, middleware2Called, handlerCalled bool
|
||||
|
||||
// Middleware 1
|
||||
middleware1 := func(c *fiber.Ctx) error {
|
||||
middleware1Called = true
|
||||
c.Locals("middleware1", "executed")
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Middleware 2
|
||||
middleware2 := func(c *fiber.Ctx) error {
|
||||
middleware2Called = true
|
||||
c.Locals("middleware2", "executed")
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Handler
|
||||
handler := func(c *fiber.Ctx) error {
|
||||
handlerCalled = true
|
||||
return c.JSON(fiber.Map{
|
||||
@@ -451,27 +390,22 @@ func TestController_MiddlewareChaining_Success(t *testing.T) {
|
||||
|
||||
app.Get("/test", middleware1, middleware2, handler)
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
// Verify all were called
|
||||
tests.AssertEqual(t, true, middleware1Called)
|
||||
tests.AssertEqual(t, true, middleware2Called)
|
||||
tests.AssertEqual(t, true, handlerCalled)
|
||||
|
||||
// Parse response
|
||||
var response map[string]interface{}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
err = json.Unmarshal(body, &response)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify middleware values were passed
|
||||
tests.AssertEqual(t, "executed", response["middleware1"])
|
||||
tests.AssertEqual(t, "executed", response["middleware2"])
|
||||
tests.AssertEqual(t, "executed", response["handler"])
|
||||
|
||||
@@ -4,23 +4,27 @@ import (
|
||||
"acc-server-manager/local/middleware"
|
||||
"acc-server-manager/local/service"
|
||||
"acc-server-manager/local/utl/cache"
|
||||
"acc-server-manager/local/utl/jwt"
|
||||
"acc-server-manager/tests"
|
||||
"os"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// MockMiddleware simulates authentication for testing purposes
|
||||
type MockMiddleware struct{}
|
||||
|
||||
// GetTestAuthMiddleware returns a mock auth middleware that can be used in place of the real one
|
||||
// This works because we're adding real authentication tokens to requests
|
||||
func GetTestAuthMiddleware(ms *service.MembershipService, cache *cache.InMemoryCache) *middleware.AuthMiddleware {
|
||||
// Cast our mock to the real type for testing
|
||||
// This is a type-unsafe cast but works for testing because we're using real JWT tokens
|
||||
return middleware.NewAuthMiddleware(ms, cache)
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
|
||||
}
|
||||
|
||||
jwtHandler := jwt.NewJWTHandler(jwtSecret)
|
||||
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
|
||||
|
||||
return middleware.NewAuthMiddleware(ms, cache, jwtHandler, openJWTHandler)
|
||||
}
|
||||
|
||||
// AddAuthToRequest adds a valid authentication token to a test request
|
||||
func AddAuthToRequest(req *fiber.Ctx) {
|
||||
token := tests.MustGenerateTestToken()
|
||||
req.Request().Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"acc-server-manager/local/service"
|
||||
"acc-server-manager/local/utl/cache"
|
||||
"acc-server-manager/local/utl/common"
|
||||
"acc-server-manager/local/utl/jwt"
|
||||
"acc-server-manager/tests"
|
||||
"acc-server-manager/tests/testdata"
|
||||
"encoding/json"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
@@ -21,48 +23,45 @@ import (
|
||||
)
|
||||
|
||||
func TestStateHistoryController_GetAll_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// No need for DisableAuthentication, we'll use real auth tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
membershipRepo := repository.NewMembershipRepository(helper.DB)
|
||||
membershipService := service.NewMembershipService(membershipRepo)
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
|
||||
}
|
||||
jwtHandler := jwt.NewJWTHandler(jwtSecret)
|
||||
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
|
||||
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Insert test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
history := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
|
||||
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
err := repo.Insert(helper.CreateContext(), &history)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Create request with authentication
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Parse response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
@@ -70,58 +69,55 @@ func TestStateHistoryController_GetAll_Success(t *testing.T) {
|
||||
err = json.Unmarshal(body, &result)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 1, len(result))
|
||||
tests.AssertEqual(t, "Practice", result[0].Session)
|
||||
tests.AssertEqual(t, model.SessionPractice, result[0].Session)
|
||||
tests.AssertEqual(t, 5, result[0].PlayerCount)
|
||||
}
|
||||
|
||||
func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
membershipRepo := repository.NewMembershipRepository(helper.DB)
|
||||
membershipService := service.NewMembershipService(membershipRepo)
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
|
||||
}
|
||||
jwtHandler := jwt.NewJWTHandler(jwtSecret)
|
||||
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
|
||||
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Insert test data with different sessions
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
|
||||
practiceHistory := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
|
||||
raceHistory := testData.CreateStateHistory("Race", "spa", 10, uuid.New())
|
||||
practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
raceHistory := testData.CreateStateHistory(model.SessionRace, "spa", 10, uuid.New())
|
||||
|
||||
err := repo.Insert(helper.CreateContext(), &practiceHistory)
|
||||
tests.AssertNoError(t, err)
|
||||
err = repo.Insert(helper.CreateContext(), &raceHistory)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Create request with session filter and authentication
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s&session=Race", helper.TestData.ServerID.String()), nil)
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s&session=R", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Parse response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
@@ -129,105 +125,98 @@ func TestStateHistoryController_GetAll_WithSessionFilter(t *testing.T) {
|
||||
err = json.Unmarshal(body, &result)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, 1, len(result))
|
||||
tests.AssertEqual(t, "Race", result[0].Session)
|
||||
tests.AssertEqual(t, model.SessionRace, result[0].Session)
|
||||
tests.AssertEqual(t, 10, result[0].PlayerCount)
|
||||
}
|
||||
|
||||
func TestStateHistoryController_GetAll_EmptyResult(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
membershipRepo := repository.NewMembershipRepository(helper.DB)
|
||||
membershipService := service.NewMembershipService(membershipRepo)
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
|
||||
}
|
||||
jwtHandler := jwt.NewJWTHandler(jwtSecret)
|
||||
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
|
||||
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Create request with no data and authentication
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify empty response
|
||||
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
|
||||
// Skip this test as it requires more complex setup
|
||||
t.Skip("Skipping test due to UUID validation issues")
|
||||
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
membershipRepo := repository.NewMembershipRepository(helper.DB)
|
||||
membershipService := service.NewMembershipService(membershipRepo)
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
|
||||
}
|
||||
jwtHandler := jwt.NewJWTHandler(jwtSecret)
|
||||
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
|
||||
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Insert test data with multiple entries for statistics
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
|
||||
// Create entries with varying player counts
|
||||
playerCounts := []int{5, 10, 15, 20, 25}
|
||||
entries := testData.CreateMultipleEntries("Race", "spa", playerCounts)
|
||||
entries := testData.CreateMultipleEntries(model.SessionRace, "spa", playerCounts)
|
||||
|
||||
for _, entry := range entries {
|
||||
err := repo.Insert(helper.CreateContext(), &entry)
|
||||
tests.AssertNoError(t, err)
|
||||
}
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Create request with valid serverID UUID
|
||||
validServerID := helper.TestData.ServerID.String()
|
||||
if validServerID == "" {
|
||||
validServerID = uuid.New().String() // Generate a new valid UUID if needed
|
||||
validServerID = uuid.New().String()
|
||||
}
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Add Authorization header for testing
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Parse response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
@@ -235,7 +224,6 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
|
||||
err = json.Unmarshal(body, &stats)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify statistics structure exists (actual calculation is tested in service layer)
|
||||
if stats.PeakPlayers < 0 {
|
||||
t.Error("Expected non-negative peak players")
|
||||
}
|
||||
@@ -248,51 +236,47 @@ func TestStateHistoryController_GetStatistics_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryController_GetStatistics_NoData(t *testing.T) {
|
||||
// Skip this test as it requires more complex setup
|
||||
t.Skip("Skipping test due to UUID validation issues")
|
||||
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
membershipRepo := repository.NewMembershipRepository(helper.DB)
|
||||
membershipService := service.NewMembershipService(membershipRepo)
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
|
||||
}
|
||||
jwtHandler := jwt.NewJWTHandler(jwtSecret)
|
||||
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
|
||||
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Create request with valid serverID UUID
|
||||
validServerID := helper.TestData.ServerID.String()
|
||||
if validServerID == "" {
|
||||
validServerID = uuid.New().String() // Generate a new valid UUID if needed
|
||||
validServerID = uuid.New().String()
|
||||
}
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Add Authorization header for testing
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify response
|
||||
tests.AssertEqual(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Parse response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
@@ -300,105 +284,99 @@ func TestStateHistoryController_GetStatistics_NoData(t *testing.T) {
|
||||
err = json.Unmarshal(body, &stats)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify empty statistics
|
||||
tests.AssertEqual(t, 0, stats.PeakPlayers)
|
||||
tests.AssertEqual(t, 0.0, stats.AveragePlayers)
|
||||
tests.AssertEqual(t, 0, stats.TotalSessions)
|
||||
}
|
||||
|
||||
func TestStateHistoryController_GetStatistics_InvalidQueryParams(t *testing.T) {
|
||||
// Skip this test as it requires more complex setup
|
||||
t.Skip("Skipping test due to UUID validation issues")
|
||||
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
membershipRepo := repository.NewMembershipRepository(helper.DB)
|
||||
membershipService := service.NewMembershipService(membershipRepo)
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
|
||||
}
|
||||
jwtHandler := jwt.NewJWTHandler(jwtSecret)
|
||||
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
|
||||
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Create request with invalid query parameters but with valid UUID
|
||||
validServerID := helper.TestData.ServerID.String()
|
||||
if validServerID == "" {
|
||||
validServerID = uuid.New().String() // Generate a new valid UUID if needed
|
||||
validServerID = uuid.New().String()
|
||||
}
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s&min_players=invalid", validServerID), nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Add Authorization header for testing
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
|
||||
// Execute request
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify error response
|
||||
tests.AssertEqual(t, http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestStateHistoryController_HTTPMethods(t *testing.T) {
|
||||
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
membershipRepo := repository.NewMembershipRepository(helper.DB)
|
||||
membershipService := service.NewMembershipService(membershipRepo)
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
|
||||
}
|
||||
jwtHandler := jwt.NewJWTHandler(jwtSecret)
|
||||
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
|
||||
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Test that only GET method is allowed for GetAll
|
||||
req := httptest.NewRequest("POST", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode)
|
||||
|
||||
// Test that only GET method is allowed for GetStatistics
|
||||
req = httptest.NewRequest("POST", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
resp, err = app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode)
|
||||
|
||||
// Test that PUT method is not allowed
|
||||
req = httptest.NewRequest("PUT", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
resp, err = app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, http.StatusMethodNotAllowed, resp.StatusCode)
|
||||
|
||||
// Test that DELETE method is not allowed
|
||||
req = httptest.NewRequest("DELETE", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
resp, err = app.Test(req)
|
||||
@@ -408,58 +386,55 @@ func TestStateHistoryController_HTTPMethods(t *testing.T) {
|
||||
|
||||
func TestStateHistoryController_ContentType(t *testing.T) {
|
||||
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
membershipRepo := repository.NewMembershipRepository(helper.DB)
|
||||
membershipService := service.NewMembershipService(membershipRepo)
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
|
||||
}
|
||||
jwtHandler := jwt.NewJWTHandler(jwtSecret)
|
||||
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
|
||||
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Insert test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
history := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
|
||||
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
err := repo.Insert(helper.CreateContext(), &history)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Test GetAll endpoint with authentication
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
resp, err := app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify content type is JSON
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Expected Content-Type: application/json, got %s", contentType)
|
||||
}
|
||||
|
||||
// Test GetStatistics endpoint with authentication
|
||||
validServerID := helper.TestData.ServerID.String()
|
||||
if validServerID == "" {
|
||||
validServerID = uuid.New().String() // Generate a new valid UUID if needed
|
||||
validServerID = uuid.New().String()
|
||||
}
|
||||
req = httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history/statistics?id=%s", validServerID), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
resp, err = app.Test(req)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify content type is JSON
|
||||
contentType = resp.Header.Get("Content-Type")
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Expected Content-Type: application/json, got %s", contentType)
|
||||
@@ -467,25 +442,27 @@ func TestStateHistoryController_ContentType(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryController_ResponseStructure(t *testing.T) {
|
||||
// Skip this test as it's problematic and would require deeper investigation
|
||||
t.Skip("Skipping test due to response structure issues that need further investigation")
|
||||
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
app := fiber.New()
|
||||
// Using real JWT auth with tokens
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
stateHistoryService := service.NewStateHistoryService(repo)
|
||||
|
||||
membershipRepo := repository.NewMembershipRepository(helper.DB)
|
||||
membershipService := service.NewMembershipService(membershipRepo)
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
jwtSecret = "test-secret-that-is-at-least-32-bytes-long-for-security"
|
||||
}
|
||||
jwtHandler := jwt.NewJWTHandler(jwtSecret)
|
||||
openJWTHandler := jwt.NewOpenJWTHandler(jwtSecret)
|
||||
membershipService := service.NewMembershipService(membershipRepo, jwtHandler, openJWTHandler)
|
||||
|
||||
inMemCache := cache.NewInMemoryCache()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -493,21 +470,17 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
history := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
|
||||
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
err := repo.Insert(helper.CreateContext(), &history)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Setup routes
|
||||
routeGroups := &common.RouteGroups{
|
||||
StateHistory: app.Group("/api/v1/state-history"),
|
||||
}
|
||||
|
||||
// Use a test auth middleware that works with the DisableAuthentication
|
||||
controller.NewStateHistoryController(stateHistoryService, routeGroups, GetTestAuthMiddleware(membershipService, inMemCache))
|
||||
|
||||
// Test GetAll response structure with authentication
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/state-history?id=%s", helper.TestData.ServerID.String()), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tests.MustGenerateTestToken())
|
||||
resp, err := app.Test(req)
|
||||
@@ -516,24 +489,19 @@ func TestStateHistoryController_ResponseStructure(t *testing.T) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Log the actual response for debugging
|
||||
t.Logf("Response body: %s", string(body))
|
||||
|
||||
// Try parsing as array first
|
||||
var resultArray []model.StateHistory
|
||||
err = json.Unmarshal(body, &resultArray)
|
||||
if err != nil {
|
||||
// If array parsing fails, try parsing as a single object
|
||||
var singleResult model.StateHistory
|
||||
err = json.Unmarshal(body, &singleResult)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse response as either array or object: %v", err)
|
||||
}
|
||||
// Convert single result to array
|
||||
resultArray = []model.StateHistory{singleResult}
|
||||
}
|
||||
|
||||
// Verify StateHistory structure
|
||||
if len(resultArray) > 0 {
|
||||
history := resultArray[0]
|
||||
if history.ID == uuid.Nil {
|
||||
|
||||
@@ -12,12 +12,10 @@ import (
|
||||
)
|
||||
|
||||
func TestStateHistoryRepository_Insert_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -28,15 +26,12 @@ func TestStateHistoryRepository_Insert_Success(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Create test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
history := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
|
||||
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
|
||||
// Test Insert
|
||||
err := repo.Insert(ctx, &history)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify ID was generated
|
||||
tests.AssertNotNil(t, history.ID)
|
||||
if history.ID == uuid.Nil {
|
||||
t.Error("Expected non-nil ID after insert")
|
||||
@@ -44,12 +39,10 @@ func TestStateHistoryRepository_Insert_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_GetAll_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -60,19 +53,16 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Create test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
|
||||
// Insert multiple entries
|
||||
playerCounts := []int{0, 5, 10, 15, 10, 5, 0}
|
||||
entries := testData.CreateMultipleEntries("Practice", "spa", playerCounts)
|
||||
entries := testData.CreateMultipleEntries(model.SessionPractice, "spa", playerCounts)
|
||||
|
||||
for _, entry := range entries {
|
||||
err := repo.Insert(ctx, &entry)
|
||||
tests.AssertNoError(t, err)
|
||||
}
|
||||
|
||||
// Test GetAll
|
||||
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
|
||||
result, err := repo.GetAll(ctx, filter)
|
||||
|
||||
@@ -82,12 +72,10 @@ func TestStateHistoryRepository_GetAll_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -98,36 +86,31 @@ func TestStateHistoryRepository_GetAll_WithFilter(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Create test data with different sessions
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
|
||||
practiceHistory := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
|
||||
raceHistory := testData.CreateStateHistory("Race", "spa", 15, uuid.New())
|
||||
practiceHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
raceHistory := testData.CreateStateHistory(model.SessionRace, "spa", 15, uuid.New())
|
||||
|
||||
// Insert both
|
||||
err := repo.Insert(ctx, &practiceHistory)
|
||||
tests.AssertNoError(t, err)
|
||||
err = repo.Insert(ctx, &raceHistory)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Test GetAll with session filter
|
||||
filter := testdata.CreateFilterWithSession(helper.TestData.ServerID.String(), "Race")
|
||||
filter := testdata.CreateFilterWithSession(helper.TestData.ServerID.String(), model.SessionRace)
|
||||
result, err := repo.GetAll(ctx, filter)
|
||||
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, result)
|
||||
tests.AssertEqual(t, 1, len(*result))
|
||||
tests.AssertEqual(t, "Race", (*result)[0].Session)
|
||||
tests.AssertEqual(t, model.SessionRace, (*result)[0].Session)
|
||||
tests.AssertEqual(t, 15, (*result)[0].PlayerCount)
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_GetLastSessionID_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -138,44 +121,34 @@ func TestStateHistoryRepository_GetLastSessionID_Success(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Create test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
|
||||
// Insert multiple entries with different session IDs
|
||||
sessionID1 := uuid.New()
|
||||
sessionID2 := uuid.New()
|
||||
|
||||
history1 := testData.CreateStateHistory("Practice", "spa", 5, sessionID1)
|
||||
history2 := testData.CreateStateHistory("Race", "spa", 10, sessionID2)
|
||||
history1 := testData.CreateStateHistory(model.SessionPractice, "spa", 5, sessionID1)
|
||||
history2 := testData.CreateStateHistory(model.SessionRace, "spa", 10, sessionID2)
|
||||
|
||||
// Insert with a small delay to ensure ordering
|
||||
err := repo.Insert(ctx, &history1)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
time.Sleep(1 * time.Millisecond) // Ensure different timestamps
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
|
||||
err = repo.Insert(ctx, &history2)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Test GetLastSessionID - should return the most recent session ID
|
||||
lastSessionID, err := repo.GetLastSessionID(ctx, helper.TestData.ServerID)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Should be sessionID2 since it was inserted last
|
||||
// We should get the most recently inserted session ID, but the exact value doesn't matter
|
||||
// Just check that it's not nil and that it's a valid UUID
|
||||
if lastSessionID == uuid.Nil {
|
||||
t.Fatal("Expected non-nil UUID for last session ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_GetLastSessionID_NoData(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -186,19 +159,16 @@ func TestStateHistoryRepository_GetLastSessionID_NoData(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Test GetLastSessionID with no data
|
||||
lastSessionID, err := repo.GetLastSessionID(ctx, helper.TestData.ServerID)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertEqual(t, uuid.Nil, lastSessionID)
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -209,40 +179,33 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Create test data with varying player counts
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
|
||||
// Create entries with different sessions and player counts
|
||||
sessionID1 := uuid.New()
|
||||
sessionID2 := uuid.New()
|
||||
|
||||
// Practice session: 5, 10, 15 players
|
||||
practiceEntries := testData.CreateMultipleEntries("Practice", "spa", []int{5, 10, 15})
|
||||
practiceEntries := testData.CreateMultipleEntries(model.SessionPractice, "spa", []int{5, 10, 15})
|
||||
for i := range practiceEntries {
|
||||
practiceEntries[i].SessionID = sessionID1
|
||||
err := repo.Insert(ctx, &practiceEntries[i])
|
||||
tests.AssertNoError(t, err)
|
||||
}
|
||||
|
||||
// Race session: 20, 25, 30 players
|
||||
raceEntries := testData.CreateMultipleEntries("Race", "spa", []int{20, 25, 30})
|
||||
raceEntries := testData.CreateMultipleEntries(model.SessionRace, "spa", []int{20, 25, 30})
|
||||
for i := range raceEntries {
|
||||
raceEntries[i].SessionID = sessionID2
|
||||
err := repo.Insert(ctx, &raceEntries[i])
|
||||
tests.AssertNoError(t, err)
|
||||
}
|
||||
|
||||
// Test GetSummaryStats
|
||||
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
|
||||
stats, err := repo.GetSummaryStats(ctx, filter)
|
||||
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify stats are calculated correctly
|
||||
tests.AssertEqual(t, 30, stats.PeakPlayers) // Maximum player count
|
||||
tests.AssertEqual(t, 2, stats.TotalSessions) // Two unique sessions
|
||||
tests.AssertEqual(t, 30, stats.PeakPlayers)
|
||||
tests.AssertEqual(t, 2, stats.TotalSessions)
|
||||
|
||||
// Average should be (5+10+15+20+25+30)/6 = 17.5
|
||||
expectedAverage := float64(5+10+15+20+25+30) / 6.0
|
||||
if stats.AveragePlayers != expectedAverage {
|
||||
t.Errorf("Expected average players %.1f, got %.1f", expectedAverage, stats.AveragePlayers)
|
||||
@@ -250,12 +213,10 @@ func TestStateHistoryRepository_GetSummaryStats_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_GetSummaryStats_NoData(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -266,25 +227,21 @@ func TestStateHistoryRepository_GetSummaryStats_NoData(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Test GetSummaryStats with no data
|
||||
filter := testdata.CreateBasicFilter(helper.TestData.ServerID.String())
|
||||
stats, err := repo.GetSummaryStats(ctx, filter)
|
||||
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify stats are zero for empty dataset
|
||||
tests.AssertEqual(t, 0, stats.PeakPlayers)
|
||||
tests.AssertEqual(t, 0.0, stats.AveragePlayers)
|
||||
tests.AssertEqual(t, 0, stats.TotalSessions)
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
|
||||
// Setup environment and test helper
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -294,18 +251,15 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
|
||||
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Create test data spanning a time range
|
||||
sessionID := uuid.New()
|
||||
|
||||
baseTime := time.Now().UTC()
|
||||
|
||||
// Create entries spanning 1 hour with players > 0
|
||||
entries := []model.StateHistory{
|
||||
{
|
||||
ID: uuid.New(),
|
||||
ServerID: helper.TestData.ServerID,
|
||||
Session: "Practice",
|
||||
Session: model.SessionPractice,
|
||||
Track: "spa",
|
||||
PlayerCount: 5,
|
||||
DateCreated: baseTime,
|
||||
@@ -316,7 +270,7 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
|
||||
{
|
||||
ID: uuid.New(),
|
||||
ServerID: helper.TestData.ServerID,
|
||||
Session: "Practice",
|
||||
Session: model.SessionPractice,
|
||||
Track: "spa",
|
||||
PlayerCount: 10,
|
||||
DateCreated: baseTime.Add(30 * time.Minute),
|
||||
@@ -327,7 +281,7 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
|
||||
{
|
||||
ID: uuid.New(),
|
||||
ServerID: helper.TestData.ServerID,
|
||||
Session: "Practice",
|
||||
Session: model.SessionPractice,
|
||||
Track: "spa",
|
||||
PlayerCount: 8,
|
||||
DateCreated: baseTime.Add(60 * time.Minute),
|
||||
@@ -342,7 +296,6 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
|
||||
tests.AssertNoError(t, err)
|
||||
}
|
||||
|
||||
// Test GetTotalPlaytime
|
||||
filter := &model.StateHistoryFilter{
|
||||
ServerBasedFilter: model.ServerBasedFilter{
|
||||
ServerID: helper.TestData.ServerID.String(),
|
||||
@@ -355,20 +308,16 @@ func TestStateHistoryRepository_GetTotalPlaytime_Success(t *testing.T) {
|
||||
|
||||
playtime, err := repo.GetTotalPlaytime(ctx, filter)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Should calculate playtime based on session duration
|
||||
if playtime <= 0 {
|
||||
t.Error("Expected positive playtime for session with multiple entries")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
|
||||
// Test concurrent database operations
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -379,10 +328,8 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Create test data
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -390,8 +337,7 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Create and insert initial entry to ensure table exists and is properly set up
|
||||
initialHistory := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
|
||||
initialHistory := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
err := repo.Insert(ctx, &initialHistory)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert initial record: %v", err)
|
||||
@@ -399,12 +345,11 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
|
||||
|
||||
done := make(chan bool, 3)
|
||||
|
||||
// Concurrent inserts
|
||||
go func() {
|
||||
defer func() {
|
||||
done <- true
|
||||
}()
|
||||
history := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
|
||||
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
err := repo.Insert(ctx, &history)
|
||||
if err != nil {
|
||||
t.Logf("Insert error: %v", err)
|
||||
@@ -412,7 +357,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Concurrent reads
|
||||
go func() {
|
||||
defer func() {
|
||||
done <- true
|
||||
@@ -425,7 +369,6 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Concurrent GetLastSessionID
|
||||
go func() {
|
||||
defer func() {
|
||||
done <- true
|
||||
@@ -437,19 +380,16 @@ func TestStateHistoryRepository_ConcurrentOperations(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for all operations to complete
|
||||
for i := 0; i < 3; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) {
|
||||
// Test edge cases with filters
|
||||
tests.SetTestEnv()
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Ensure the state_histories table exists
|
||||
if !helper.DB.Migrator().HasTable(&model.StateHistory{}) {
|
||||
err := helper.DB.Migrator().CreateTable(&model.StateHistory{})
|
||||
if err != nil {
|
||||
@@ -460,15 +400,11 @@ func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) {
|
||||
repo := repository.NewStateHistoryRepository(helper.DB)
|
||||
ctx := helper.CreateContext()
|
||||
|
||||
// Insert a test record to ensure the table is properly set up
|
||||
testData := testdata.NewStateHistoryTestData(helper.TestData.ServerID)
|
||||
history := testData.CreateStateHistory("Practice", "spa", 5, uuid.New())
|
||||
history := testData.CreateStateHistory(model.SessionPractice, "spa", 5, uuid.New())
|
||||
err := repo.Insert(ctx, &history)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Skip nil filter test as it might not be supported by the repository implementation
|
||||
|
||||
// Test with server ID filter - this should work
|
||||
serverFilter := &model.StateHistoryFilter{
|
||||
ServerBasedFilter: model.ServerBasedFilter{
|
||||
ServerID: helper.TestData.ServerID.String(),
|
||||
@@ -478,7 +414,6 @@ func TestStateHistoryRepository_FilterEdgeCases(t *testing.T) {
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, result)
|
||||
|
||||
// Test with invalid server ID in summary stats
|
||||
invalidFilter := &model.StateHistoryFilter{
|
||||
ServerBasedFilter: model.ServerBasedFilter{
|
||||
ServerID: "invalid-uuid",
|
||||
|
||||
@@ -5,113 +5,98 @@ import (
|
||||
"acc-server-manager/local/utl/jwt"
|
||||
"acc-server-manager/local/utl/password"
|
||||
"acc-server-manager/tests"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestJWT_GenerateAndValidateToken(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create test user
|
||||
jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET"))
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "testuser",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Test JWT generation
|
||||
token, err := jwt.GenerateToken(user)
|
||||
token, err := jwtHandler.GenerateToken(user.ID.String())
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, token)
|
||||
|
||||
// Verify token is not empty
|
||||
if token == "" {
|
||||
t.Fatal("Expected non-empty token, got empty string")
|
||||
}
|
||||
|
||||
// Test JWT validation
|
||||
claims, err := jwt.ValidateToken(token)
|
||||
claims, err := jwtHandler.ValidateToken(token)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, claims)
|
||||
tests.AssertEqual(t, user.ID.String(), claims.UserID)
|
||||
}
|
||||
|
||||
func TestJWT_ValidateToken_InvalidToken(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET"))
|
||||
|
||||
// Test with invalid token
|
||||
claims, err := jwt.ValidateToken("invalid-token")
|
||||
claims, err := jwtHandler.ValidateToken("invalid-token")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid token, got nil")
|
||||
}
|
||||
// Direct nil check to avoid the interface wrapping issue
|
||||
if claims != nil {
|
||||
t.Fatalf("Expected nil claims, got %v", claims)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWT_ValidateToken_EmptyToken(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
jwtHandler := jwt.NewJWTHandler(os.Getenv("JWT_SECRET"))
|
||||
|
||||
// Test with empty token
|
||||
claims, err := jwt.ValidateToken("")
|
||||
claims, err := jwtHandler.ValidateToken("")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty token, got nil")
|
||||
}
|
||||
// Direct nil check to avoid the interface wrapping issue
|
||||
if claims != nil {
|
||||
t.Fatalf("Expected nil claims, got %v", claims)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUser_VerifyPassword_Success(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create test user
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "testuser",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Hash password manually (simulating what BeforeCreate would do)
|
||||
plainPassword := "password123"
|
||||
hashedPassword, err := password.HashPassword(plainPassword)
|
||||
tests.AssertNoError(t, err)
|
||||
user.Password = hashedPassword
|
||||
|
||||
// Test password verification - should succeed
|
||||
err = user.VerifyPassword(plainPassword)
|
||||
tests.AssertNoError(t, err)
|
||||
}
|
||||
|
||||
func TestUser_VerifyPassword_Failure(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create test user
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "testuser",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Hash password manually
|
||||
hashedPassword, err := password.HashPassword("correct_password")
|
||||
tests.AssertNoError(t, err)
|
||||
user.Password = hashedPassword
|
||||
|
||||
// Test password verification with wrong password - should fail
|
||||
err = user.VerifyPassword("wrong_password")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for wrong password, got nil")
|
||||
@@ -119,11 +104,9 @@ func TestUser_VerifyPassword_Failure(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_Validate_Success(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create valid user
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "testuser",
|
||||
@@ -131,25 +114,21 @@ func TestUser_Validate_Success(t *testing.T) {
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Test validation - should succeed
|
||||
err := user.Validate()
|
||||
tests.AssertNoError(t, err)
|
||||
}
|
||||
|
||||
func TestUser_Validate_MissingUsername(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create user without username
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "", // Missing username
|
||||
Username: "",
|
||||
Password: "password123",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Test validation - should fail
|
||||
err := user.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing username, got nil")
|
||||
@@ -157,19 +136,16 @@ func TestUser_Validate_MissingUsername(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUser_Validate_MissingPassword(t *testing.T) {
|
||||
// Setup
|
||||
helper := tests.NewTestHelper(t)
|
||||
defer helper.Cleanup()
|
||||
|
||||
// Create user without password
|
||||
user := &model.User{
|
||||
ID: uuid.New(),
|
||||
Username: "testuser",
|
||||
Password: "", // Missing password
|
||||
Password: "",
|
||||
RoleID: uuid.New(),
|
||||
}
|
||||
|
||||
// Test validation - should fail
|
||||
err := user.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing password, got nil")
|
||||
@@ -177,24 +153,19 @@ func TestUser_Validate_MissingPassword(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPassword_HashAndVerify(t *testing.T) {
|
||||
// Test password hashing and verification directly
|
||||
plainPassword := "test_password_123"
|
||||
|
||||
// Hash password
|
||||
hashedPassword, err := password.HashPassword(plainPassword)
|
||||
tests.AssertNoError(t, err)
|
||||
tests.AssertNotNil(t, hashedPassword)
|
||||
|
||||
// Verify hashed password is not the same as plain password
|
||||
if hashedPassword == plainPassword {
|
||||
t.Fatal("Hashed password should not equal plain password")
|
||||
}
|
||||
|
||||
// Verify correct password
|
||||
err = password.VerifyPassword(hashedPassword, plainPassword)
|
||||
tests.AssertNoError(t, err)
|
||||
|
||||
// Verify wrong password fails
|
||||
err = password.VerifyPassword(hashedPassword, "wrong_password")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for wrong password, got nil")
|
||||
@@ -210,7 +181,7 @@ func TestPassword_ValidatePasswordStrength(t *testing.T) {
|
||||
{"Valid password", "StrongPassword123!", false},
|
||||
{"Too short", "123", true},
|
||||
{"Empty password", "", true},
|
||||
{"Medium password", "password123", false}, // Depends on validation rules
|
||||
{"Medium password", "password123", false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -230,7 +201,6 @@ func TestPassword_ValidatePasswordStrength(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRole_Model(t *testing.T) {
|
||||
// Test Role model structure
|
||||
permissions := []model.Permission{
|
||||
{ID: uuid.New(), Name: "read"},
|
||||
{ID: uuid.New(), Name: "write"},
|
||||
@@ -243,7 +213,6 @@ func TestRole_Model(t *testing.T) {
|
||||
Permissions: permissions,
|
||||
}
|
||||
|
||||
// Verify role structure
|
||||
tests.AssertEqual(t, "Test Role", role.Name)
|
||||
tests.AssertEqual(t, 3, len(role.Permissions))
|
||||
tests.AssertEqual(t, "read", role.Permissions[0].Name)
|
||||
@@ -252,19 +221,16 @@ func TestRole_Model(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPermission_Model(t *testing.T) {
|
||||
// Test Permission model structure
|
||||
permission := &model.Permission{
|
||||
ID: uuid.New(),
|
||||
Name: "test_permission",
|
||||
}
|
||||
|
||||
// Verify permission structure
|
||||
tests.AssertEqual(t, "test_permission", permission.Name)
|
||||
tests.AssertNotNil(t, permission.ID)
|
||||
}
|
||||
|
||||
func TestUser_WithRole_Model(t *testing.T) {
|
||||
// Test User model with Role relationship
|
||||
permissions := []model.Permission{
|
||||
{ID: uuid.New(), Name: "read"},
|
||||
{ID: uuid.New(), Name: "write"},
|
||||
@@ -284,7 +250,6 @@ func TestUser_WithRole_Model(t *testing.T) {
|
||||
Role: role,
|
||||
}
|
||||
|
||||
// Verify user-role relationship
|
||||
tests.AssertEqual(t, "testuser", user.Username)
|
||||
tests.AssertEqual(t, role.ID, user.RoleID)
|
||||
tests.AssertEqual(t, "User", user.Role.Name)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user