init bootstrap
This commit is contained in:
102
local/utl/cache/cache.go
vendored
Normal file
102
local/utl/cache/cache.go
vendored
Normal file
@@ -0,0 +1,102 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"omega-server/local/utl/logging"
|
||||
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
// CacheItem represents an item in the cache
|
||||
type CacheItem struct {
|
||||
Value interface{}
|
||||
Expiration int64
|
||||
}
|
||||
|
||||
// InMemoryCache is a thread-safe in-memory cache
|
||||
type InMemoryCache struct {
|
||||
items map[string]CacheItem
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewInMemoryCache creates and returns a new InMemoryCache instance
|
||||
func NewInMemoryCache() *InMemoryCache {
|
||||
return &InMemoryCache{
|
||||
items: make(map[string]CacheItem),
|
||||
}
|
||||
}
|
||||
|
||||
// Set adds an item to the cache with an expiration duration (in seconds)
|
||||
func (c *InMemoryCache) Set(key string, value interface{}, duration time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
var expiration int64
|
||||
if duration > 0 {
|
||||
expiration = time.Now().Add(duration).UnixNano()
|
||||
}
|
||||
|
||||
c.items[key] = CacheItem{
|
||||
Value: value,
|
||||
Expiration: expiration,
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves an item from the cache
|
||||
func (c *InMemoryCache) Get(key string) (interface{}, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
item, found := c.items[key]
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if item.Expiration > 0 && time.Now().UnixNano() > item.Expiration {
|
||||
// Item has expired, but don't delete here to avoid lock upgrade.
|
||||
// It will be overwritten on the next Set.
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return item.Value, true
|
||||
}
|
||||
|
||||
// Delete removes an item from the cache
|
||||
func (c *InMemoryCache) Delete(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.items, key)
|
||||
}
|
||||
|
||||
// GetOrSet retrieves an item from the cache. If the item is not found, it
|
||||
// calls the provided function to get the value, sets it in the cache, and
|
||||
// returns it.
|
||||
func GetOrSet[T any](c *InMemoryCache, key string, duration time.Duration, fetcher func() (T, error)) (T, error) {
|
||||
if cached, found := c.Get(key); found {
|
||||
if value, ok := cached.(T); ok {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
value, err := fetcher()
|
||||
if err != nil {
|
||||
var zero T
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.Set(key, value, duration)
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Start initializes the cache and provides it to the DI container.
|
||||
func Start(di *dig.Container) {
|
||||
cache := NewInMemoryCache()
|
||||
err := di.Provide(func() *InMemoryCache {
|
||||
return cache
|
||||
})
|
||||
if err != nil {
|
||||
logging.Panic("failed to provide cache")
|
||||
}
|
||||
}
|
||||
311
local/utl/common/types.go
Normal file
311
local/utl/common/types.go
Normal file
@@ -0,0 +1,311 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// RouteGroups holds the different route groups for API organization
|
||||
type RouteGroups struct {
|
||||
API fiber.Router
|
||||
Auth fiber.Router
|
||||
Users fiber.Router
|
||||
Roles fiber.Router
|
||||
System fiber.Router
|
||||
Admin fiber.Router
|
||||
}
|
||||
|
||||
// HTTPError represents a structured HTTP error
|
||||
type HTTPError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details map[string]interface{} `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e HTTPError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// NewHTTPError creates a new HTTP error
|
||||
func NewHTTPError(code int, message string) HTTPError {
|
||||
return HTTPError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHTTPErrorWithDetails creates a new HTTP error with details
|
||||
func NewHTTPErrorWithDetails(code int, message string, details map[string]interface{}) HTTPError {
|
||||
return HTTPError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Details: details,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidationError represents a validation error
|
||||
type ValidationError struct {
|
||||
Field string `json:"field"`
|
||||
Message string `json:"message"`
|
||||
Value string `json:"value,omitempty"`
|
||||
}
|
||||
|
||||
// ValidationErrors represents multiple validation errors
|
||||
type ValidationErrors []ValidationError
|
||||
|
||||
// Error implements the error interface
|
||||
func (ve ValidationErrors) Error() string {
|
||||
if len(ve) == 0 {
|
||||
return "validation failed"
|
||||
}
|
||||
return ve[0].Message
|
||||
}
|
||||
|
||||
// APIResponse represents a standard API response structure
|
||||
type APIResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Code int `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
// SuccessResponse creates a success response
|
||||
func SuccessResponse(data interface{}, message ...string) APIResponse {
|
||||
response := APIResponse{
|
||||
Success: true,
|
||||
Data: data,
|
||||
}
|
||||
if len(message) > 0 {
|
||||
response.Message = message[0]
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
// ErrorResponse creates an error response
|
||||
func ErrorResponse(code int, message string) APIResponse {
|
||||
return APIResponse{
|
||||
Success: false,
|
||||
Error: message,
|
||||
Code: code,
|
||||
}
|
||||
}
|
||||
|
||||
// PaginationRequest represents pagination parameters
|
||||
type PaginationRequest struct {
|
||||
Page int `query:"page" validate:"min=1"`
|
||||
Limit int `query:"limit" validate:"min=1,max=100"`
|
||||
Sort string `query:"sort"`
|
||||
Order string `query:"order" validate:"oneof=asc desc"`
|
||||
Search string `query:"search"`
|
||||
Filter string `query:"filter"`
|
||||
Category string `query:"category"`
|
||||
}
|
||||
|
||||
// DefaultPagination returns default pagination values
|
||||
func DefaultPagination() PaginationRequest {
|
||||
return PaginationRequest{
|
||||
Page: 1,
|
||||
Limit: 10,
|
||||
Sort: "dateCreated",
|
||||
Order: "desc",
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates pagination parameters
|
||||
func (p *PaginationRequest) Validate() {
|
||||
if p.Page < 1 {
|
||||
p.Page = 1
|
||||
}
|
||||
if p.Limit < 1 || p.Limit > 100 {
|
||||
p.Limit = 10
|
||||
}
|
||||
if p.Sort == "" {
|
||||
p.Sort = "dateCreated"
|
||||
}
|
||||
if p.Order != "asc" && p.Order != "desc" {
|
||||
p.Order = "desc"
|
||||
}
|
||||
}
|
||||
|
||||
// Offset calculates the offset for database queries
|
||||
func (p *PaginationRequest) Offset() int {
|
||||
return (p.Page - 1) * p.Limit
|
||||
}
|
||||
|
||||
// PaginationResponse represents paginated response metadata
|
||||
type PaginationResponse struct {
|
||||
Page int `json:"page"`
|
||||
Limit int `json:"limit"`
|
||||
Total int64 `json:"total"`
|
||||
TotalPages int `json:"totalPages"`
|
||||
HasNext bool `json:"hasNext"`
|
||||
HasPrevious bool `json:"hasPrevious"`
|
||||
NextPage *int `json:"nextPage,omitempty"`
|
||||
PreviousPage *int `json:"previousPage,omitempty"`
|
||||
}
|
||||
|
||||
// NewPaginationResponse creates a new pagination response
|
||||
func NewPaginationResponse(page, limit int, total int64) PaginationResponse {
|
||||
totalPages := int((total + int64(limit) - 1) / int64(limit))
|
||||
|
||||
response := PaginationResponse{
|
||||
Page: page,
|
||||
Limit: limit,
|
||||
Total: total,
|
||||
TotalPages: totalPages,
|
||||
HasNext: page < totalPages,
|
||||
HasPrevious: page > 1,
|
||||
}
|
||||
|
||||
if response.HasNext {
|
||||
next := page + 1
|
||||
response.NextPage = &next
|
||||
}
|
||||
|
||||
if response.HasPrevious {
|
||||
prev := page - 1
|
||||
response.PreviousPage = &prev
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// PaginatedResponse represents a paginated API response
|
||||
type PaginatedResponse struct {
|
||||
APIResponse
|
||||
Pagination PaginationResponse `json:"pagination"`
|
||||
}
|
||||
|
||||
// NewPaginatedResponse creates a new paginated response
|
||||
func NewPaginatedResponse(data interface{}, pagination PaginationResponse, message ...string) PaginatedResponse {
|
||||
response := PaginatedResponse{
|
||||
APIResponse: SuccessResponse(data, message...),
|
||||
Pagination: pagination,
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
// RequestContext represents the context of an HTTP request
|
||||
type RequestContext struct {
|
||||
UserID string
|
||||
Email string
|
||||
Username string
|
||||
Roles []string
|
||||
IP string
|
||||
UserAgent string
|
||||
RequestID string
|
||||
SessionID string
|
||||
}
|
||||
|
||||
// GetUserInfo returns user information from context
|
||||
func (rc *RequestContext) GetUserInfo() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"userId": rc.UserID,
|
||||
"email": rc.Email,
|
||||
"username": rc.Username,
|
||||
"roles": rc.Roles,
|
||||
}
|
||||
}
|
||||
|
||||
// HasRole checks if the user has a specific role
|
||||
func (rc *RequestContext) HasRole(role string) bool {
|
||||
for _, r := range rc.Roles {
|
||||
if r == role {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsAdmin checks if the user has admin role
|
||||
func (rc *RequestContext) IsAdmin() bool {
|
||||
return rc.HasRole("admin")
|
||||
}
|
||||
|
||||
// SortDirection represents sort direction
|
||||
type SortDirection string
|
||||
|
||||
const (
|
||||
SortAsc SortDirection = "asc"
|
||||
SortDesc SortDirection = "desc"
|
||||
)
|
||||
|
||||
// SortField represents a field to sort by
|
||||
type SortField struct {
|
||||
Field string `json:"field"`
|
||||
Direction SortDirection `json:"direction"`
|
||||
}
|
||||
|
||||
// FilterOperator represents filter operators
|
||||
type FilterOperator string
|
||||
|
||||
const (
|
||||
FilterEqual FilterOperator = "eq"
|
||||
FilterNotEqual FilterOperator = "ne"
|
||||
FilterGreaterThan FilterOperator = "gt"
|
||||
FilterGreaterEqual FilterOperator = "gte"
|
||||
FilterLessThan FilterOperator = "lt"
|
||||
FilterLessEqual FilterOperator = "lte"
|
||||
FilterLike FilterOperator = "like"
|
||||
FilterIn FilterOperator = "in"
|
||||
FilterNotIn FilterOperator = "nin"
|
||||
FilterIsNull FilterOperator = "isnull"
|
||||
FilterIsNotNull FilterOperator = "isnotnull"
|
||||
)
|
||||
|
||||
// FilterCondition represents a filter condition
|
||||
type FilterCondition struct {
|
||||
Field string `json:"field"`
|
||||
Operator FilterOperator `json:"operator"`
|
||||
Value interface{} `json:"value"`
|
||||
}
|
||||
|
||||
// QueryOptions represents query options for database operations
|
||||
type QueryOptions struct {
|
||||
Filters []FilterCondition `json:"filters,omitempty"`
|
||||
Sort []SortField `json:"sort,omitempty"`
|
||||
Pagination *PaginationRequest `json:"pagination,omitempty"`
|
||||
Include []string `json:"include,omitempty"`
|
||||
Exclude []string `json:"exclude,omitempty"`
|
||||
}
|
||||
|
||||
// BulkOperation represents a bulk operation request
|
||||
type BulkOperation struct {
|
||||
Action string `json:"action" validate:"required"`
|
||||
IDs []string `json:"ids" validate:"required,min=1"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// AuditInfo represents audit information for operations
|
||||
type AuditInfo struct {
|
||||
Action string `json:"action"`
|
||||
Resource string `json:"resource"`
|
||||
ResourceID string `json:"resourceId"`
|
||||
UserID string `json:"userId"`
|
||||
Details map[string]interface{} `json:"details"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// HealthStatus represents system health status
|
||||
type HealthStatus struct {
|
||||
Status string `json:"status"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Version string `json:"version"`
|
||||
Environment string `json:"environment"`
|
||||
Uptime string `json:"uptime"`
|
||||
Database map[string]interface{} `json:"database"`
|
||||
Cache map[string]interface{} `json:"cache"`
|
||||
Memory map[string]interface{} `json:"memory"`
|
||||
}
|
||||
|
||||
// SystemStats represents system statistics
|
||||
type SystemStats struct {
|
||||
TotalUsers int64 `json:"totalUsers"`
|
||||
ActiveUsers int64 `json:"activeUsers"`
|
||||
TotalRoles int64 `json:"totalRoles"`
|
||||
TotalRequests int64 `json:"totalRequests"`
|
||||
ErrorRate float64 `json:"errorRate"`
|
||||
ResponseTime float64 `json:"avgResponseTime"`
|
||||
}
|
||||
220
local/utl/configs/config.go
Normal file
220
local/utl/configs/config.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package configs
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// API configuration constants
|
||||
const (
|
||||
// Prefix for all API routes
|
||||
Prefix = "/api/v1"
|
||||
|
||||
// Default values
|
||||
DefaultPort = "3000"
|
||||
DefaultDatabaseName = "app.db"
|
||||
DefaultJWTExpiryHours = 24
|
||||
DefaultPasswordMinLen = 8
|
||||
DefaultMaxLoginAttempts = 5
|
||||
DefaultLockoutMinutes = 30
|
||||
DefaultRateLimitReqs = 100
|
||||
DefaultRateLimitWindow = 1 // minute
|
||||
)
|
||||
|
||||
// Environment variable keys
|
||||
const (
|
||||
EnvPort = "PORT"
|
||||
EnvDatabaseName = "DB_NAME"
|
||||
EnvJWTSecret = "JWT_SECRET"
|
||||
EnvAppSecret = "APP_SECRET"
|
||||
EnvAppSecretCode = "APP_SECRET_CODE"
|
||||
EnvEncryptionKey = "ENCRYPTION_KEY"
|
||||
EnvCORSAllowedOrigin = "CORS_ALLOWED_ORIGIN"
|
||||
EnvLogLevel = "LOG_LEVEL"
|
||||
EnvDebugMode = "DEBUG_MODE"
|
||||
EnvDefaultAdminPassword = "DEFAULT_ADMIN_PASSWORD"
|
||||
EnvJWTAccessTTLHours = "JWT_ACCESS_TTL_HOURS"
|
||||
EnvJWTRefreshTTLDays = "JWT_REFRESH_TTL_DAYS"
|
||||
EnvJWTIssuer = "JWT_ISSUER"
|
||||
)
|
||||
|
||||
// GetEnv returns environment variable value or default
|
||||
func GetEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// GetEnvInt returns environment variable as integer or default
|
||||
func GetEnvInt(key string, defaultValue int) int {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
if intValue, err := strconv.Atoi(value); err == nil {
|
||||
return intValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// GetEnvBool returns environment variable as boolean or default
|
||||
func GetEnvBool(key string, defaultValue bool) bool {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
if boolValue, err := strconv.ParseBool(value); err == nil {
|
||||
return boolValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// GetEnvDuration returns environment variable as duration or default
|
||||
func GetEnvDuration(key string, defaultValue time.Duration) time.Duration {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
if duration, err := time.ParseDuration(value); err == nil {
|
||||
return duration
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// Config holds application configuration
|
||||
type Config struct {
|
||||
// Server configuration
|
||||
Port string
|
||||
DatabaseName string
|
||||
LogLevel string
|
||||
DebugMode bool
|
||||
|
||||
// Security configuration
|
||||
JWTSecret string
|
||||
AppSecret string
|
||||
AppSecretCode string
|
||||
EncryptionKey string
|
||||
JWTAccessTTLHours int
|
||||
JWTRefreshTTLDays int
|
||||
JWTIssuer string
|
||||
PasswordMinLength int
|
||||
MaxLoginAttempts int
|
||||
LockoutDurationMins int
|
||||
|
||||
// CORS configuration
|
||||
CORSAllowedOrigin string
|
||||
|
||||
// Rate limiting
|
||||
RateLimitRequests int
|
||||
RateLimitWindow int // minutes
|
||||
|
||||
// Admin configuration
|
||||
DefaultAdminPassword string
|
||||
}
|
||||
|
||||
// LoadConfig loads configuration from environment variables
|
||||
func LoadConfig() *Config {
|
||||
return &Config{
|
||||
// Server configuration
|
||||
Port: GetEnv(EnvPort, DefaultPort),
|
||||
DatabaseName: GetEnv(EnvDatabaseName, DefaultDatabaseName),
|
||||
LogLevel: GetEnv(EnvLogLevel, "INFO"),
|
||||
DebugMode: GetEnvBool(EnvDebugMode, false),
|
||||
|
||||
// Security configuration
|
||||
JWTSecret: GetEnv(EnvJWTSecret, ""),
|
||||
AppSecret: GetEnv(EnvAppSecret, ""),
|
||||
AppSecretCode: GetEnv(EnvAppSecretCode, ""),
|
||||
EncryptionKey: GetEnv(EnvEncryptionKey, ""),
|
||||
JWTAccessTTLHours: GetEnvInt(EnvJWTAccessTTLHours, DefaultJWTExpiryHours),
|
||||
JWTRefreshTTLDays: GetEnvInt(EnvJWTRefreshTTLDays, 7),
|
||||
JWTIssuer: GetEnv(EnvJWTIssuer, "omega-server"),
|
||||
PasswordMinLength: GetEnvInt("PASSWORD_MIN_LENGTH", DefaultPasswordMinLen),
|
||||
MaxLoginAttempts: GetEnvInt("MAX_LOGIN_ATTEMPTS", DefaultMaxLoginAttempts),
|
||||
LockoutDurationMins: GetEnvInt("LOCKOUT_DURATION_MINUTES", DefaultLockoutMinutes),
|
||||
|
||||
// CORS configuration
|
||||
CORSAllowedOrigin: GetEnv(EnvCORSAllowedOrigin, "http://localhost:5173"),
|
||||
|
||||
// Rate limiting
|
||||
RateLimitRequests: GetEnvInt("RATE_LIMIT_REQUESTS", DefaultRateLimitReqs),
|
||||
RateLimitWindow: GetEnvInt("RATE_LIMIT_WINDOW_MINUTES", DefaultRateLimitWindow),
|
||||
|
||||
// Admin configuration
|
||||
DefaultAdminPassword: GetEnv(EnvDefaultAdminPassword, ""),
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the configuration
|
||||
func (c *Config) Validate() []string {
|
||||
var errors []string
|
||||
|
||||
// Required security settings
|
||||
if c.JWTSecret == "" {
|
||||
errors = append(errors, "JWT_SECRET is required")
|
||||
}
|
||||
if c.AppSecret == "" {
|
||||
errors = append(errors, "APP_SECRET is required")
|
||||
}
|
||||
if c.AppSecretCode == "" {
|
||||
errors = append(errors, "APP_SECRET_CODE is required")
|
||||
}
|
||||
if c.EncryptionKey == "" {
|
||||
errors = append(errors, "ENCRYPTION_KEY is required")
|
||||
}
|
||||
|
||||
// Validate encryption key length (must be 32 characters for AES-256)
|
||||
if len(c.EncryptionKey) != 32 {
|
||||
errors = append(errors, "ENCRYPTION_KEY must be exactly 32 characters long")
|
||||
}
|
||||
|
||||
// Validate JWT settings
|
||||
if c.JWTAccessTTLHours <= 0 {
|
||||
errors = append(errors, "JWT_ACCESS_TTL_HOURS must be greater than 0")
|
||||
}
|
||||
if c.JWTRefreshTTLDays <= 0 {
|
||||
errors = append(errors, "JWT_REFRESH_TTL_DAYS must be greater than 0")
|
||||
}
|
||||
|
||||
// Validate password settings
|
||||
if c.PasswordMinLength < 8 {
|
||||
errors = append(errors, "PASSWORD_MIN_LENGTH must be at least 8")
|
||||
}
|
||||
|
||||
// Validate rate limiting
|
||||
if c.RateLimitRequests <= 0 {
|
||||
errors = append(errors, "RATE_LIMIT_REQUESTS must be greater than 0")
|
||||
}
|
||||
if c.RateLimitWindow <= 0 {
|
||||
errors = append(errors, "RATE_LIMIT_WINDOW_MINUTES must be greater than 0")
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// IsProduction returns true if running in production mode
|
||||
func (c *Config) IsProduction() bool {
|
||||
env := GetEnv("GO_ENV", "development")
|
||||
return env == "production"
|
||||
}
|
||||
|
||||
// IsDevelopment returns true if running in development mode
|
||||
func (c *Config) IsDevelopment() bool {
|
||||
return !c.IsProduction()
|
||||
}
|
||||
|
||||
// GetJWTAccessTTL returns JWT access token TTL as duration
|
||||
func (c *Config) GetJWTAccessTTL() time.Duration {
|
||||
return time.Duration(c.JWTAccessTTLHours) * time.Hour
|
||||
}
|
||||
|
||||
// GetJWTRefreshTTL returns JWT refresh token TTL as duration
|
||||
func (c *Config) GetJWTRefreshTTL() time.Duration {
|
||||
return time.Duration(c.JWTRefreshTTLDays) * 24 * time.Hour
|
||||
}
|
||||
|
||||
// GetLockoutDuration returns account lockout duration
|
||||
func (c *Config) GetLockoutDuration() time.Duration {
|
||||
return time.Duration(c.LockoutDurationMins) * time.Minute
|
||||
}
|
||||
|
||||
// GetRateLimitWindow returns rate limit window as duration
|
||||
func (c *Config) GetRateLimitWindow() time.Duration {
|
||||
return time.Duration(c.RateLimitWindow) * time.Minute
|
||||
}
|
||||
320
local/utl/db/db.go
Normal file
320
local/utl/db/db.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"omega-server/local/model"
|
||||
"omega-server/local/utl/logging"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"go.uber.org/dig"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
func Start(di *dig.Container) {
|
||||
dbName := os.Getenv("DB_NAME")
|
||||
if dbName == "" {
|
||||
dbName = "app.db"
|
||||
}
|
||||
|
||||
// Configure GORM logger
|
||||
gormLogger := logger.Default
|
||||
if os.Getenv("LOG_LEVEL") == "DEBUG" {
|
||||
gormLogger = logger.Default.LogMode(logger.Info)
|
||||
} else {
|
||||
gormLogger = logger.Default.LogMode(logger.Silent)
|
||||
}
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{
|
||||
Logger: gormLogger,
|
||||
})
|
||||
if err != nil {
|
||||
logging.Panic("failed to connect database: " + err.Error())
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
logging.Panic("failed to get database instance: " + err.Error())
|
||||
}
|
||||
|
||||
// Set connection pool settings
|
||||
sqlDB.SetMaxIdleConns(10)
|
||||
sqlDB.SetMaxOpenConns(100)
|
||||
sqlDB.SetConnMaxLifetime(time.Hour)
|
||||
|
||||
err = di.Provide(func() *gorm.DB {
|
||||
return db
|
||||
})
|
||||
if err != nil {
|
||||
logging.Panic("failed to bind database: " + err.Error())
|
||||
}
|
||||
|
||||
logging.Info("Database connected successfully")
|
||||
Migrate(db)
|
||||
}
|
||||
|
||||
func Migrate(db *gorm.DB) {
|
||||
logging.Info("Starting database migration...")
|
||||
|
||||
err := db.AutoMigrate(
|
||||
&model.User{},
|
||||
&model.Role{},
|
||||
&model.Permission{},
|
||||
&model.SystemConfig{},
|
||||
&model.AuditLog{},
|
||||
&model.SecurityEvent{},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
logging.Panic("failed to migrate database models: " + err.Error())
|
||||
}
|
||||
|
||||
logging.Info("Database migration completed successfully")
|
||||
Seed(db)
|
||||
}
|
||||
|
||||
func Seed(db *gorm.DB) error {
|
||||
logging.Info("Starting database seeding...")
|
||||
|
||||
if err := seedRoles(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := seedPermissions(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := seedDefaultAdmin(db); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := seedSystemConfigs(db); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logging.Info("Database seeding completed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func seedRoles(db *gorm.DB) error {
|
||||
roles := []model.Role{
|
||||
{Name: "admin", Description: "Administrator with full access"},
|
||||
{Name: "user", Description: "Regular user with limited access"},
|
||||
{Name: "viewer", Description: "Read-only access"},
|
||||
}
|
||||
|
||||
for _, role := range roles {
|
||||
var existingRole model.Role
|
||||
err := db.Where("name = ?", role.Name).First(&existingRole).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
role.Init()
|
||||
if err := db.Create(&role).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
logging.Info("Created role: %s", role.Name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func seedPermissions(db *gorm.DB) error {
|
||||
permissions := []model.Permission{
|
||||
{Name: "user:create", Description: "Create new users"},
|
||||
{Name: "user:read", Description: "Read user information"},
|
||||
{Name: "user:update", Description: "Update user information"},
|
||||
{Name: "user:delete", Description: "Delete users"},
|
||||
{Name: "role:create", Description: "Create new roles"},
|
||||
{Name: "role:read", Description: "Read role information"},
|
||||
{Name: "role:update", Description: "Update role information"},
|
||||
{Name: "role:delete", Description: "Delete roles"},
|
||||
{Name: "system:config", Description: "Access system configuration"},
|
||||
{Name: "system:logs", Description: "Access system logs"},
|
||||
{Name: "system:admin", Description: "Full system administration"},
|
||||
}
|
||||
|
||||
for _, permission := range permissions {
|
||||
var existingPermission model.Permission
|
||||
err := db.Where("name = ?", permission.Name).First(&existingPermission).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
permission.Init()
|
||||
if err := db.Create(&permission).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
logging.Info("Created permission: %s", permission.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Assign all permissions to admin role
|
||||
var adminRole model.Role
|
||||
if err := db.Where("name = ?", "admin").First(&adminRole).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var allPermissions []model.Permission
|
||||
if err := db.Find(&allPermissions).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.Model(&adminRole).Association("Permissions").Replace(allPermissions); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Assign basic permissions to user role
|
||||
var userRole model.Role
|
||||
if err := db.Where("name = ?", "user").First(&userRole).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var userPermissions []model.Permission
|
||||
userPermissionNames := []string{"user:read", "role:read"}
|
||||
if err := db.Where("name IN ?", userPermissionNames).Find(&userPermissions).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.Model(&userRole).Association("Permissions").Replace(userPermissions); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Assign read permissions to viewer role
|
||||
var viewerRole model.Role
|
||||
if err := db.Where("name = ?", "viewer").First(&viewerRole).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var viewerPermissions []model.Permission
|
||||
viewerPermissionNames := []string{"user:read", "role:read"}
|
||||
if err := db.Where("name IN ?", viewerPermissionNames).Find(&viewerPermissions).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.Model(&viewerRole).Association("Permissions").Replace(viewerPermissions); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func seedDefaultAdmin(db *gorm.DB) error {
|
||||
// Check if admin user already exists
|
||||
var existingAdmin model.User
|
||||
err := db.Where("email = ?", "admin@example.com").First(&existingAdmin).Error
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
return nil // Admin already exists or other error
|
||||
}
|
||||
|
||||
// Get admin role
|
||||
var adminRole model.Role
|
||||
if err := db.Where("name = ?", "admin").First(&adminRole).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create default admin user
|
||||
defaultPassword := os.Getenv("DEFAULT_ADMIN_PASSWORD")
|
||||
if defaultPassword == "" {
|
||||
defaultPassword = "admin123"
|
||||
}
|
||||
|
||||
admin := model.User{
|
||||
Email: "admin@example.com",
|
||||
Username: "admin",
|
||||
Name: "System Administrator",
|
||||
Active: true,
|
||||
}
|
||||
admin.Init()
|
||||
|
||||
if err := admin.SetPassword(defaultPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.Create(&admin).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Assign admin role
|
||||
if err := db.Model(&admin).Association("Roles").Append(&adminRole); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logging.Info("Created default admin user with email: %s", admin.Email)
|
||||
logging.Warn("Default admin password is: %s - PLEASE CHANGE THIS IMMEDIATELY!", defaultPassword)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func seedSystemConfigs(db *gorm.DB) error {
|
||||
configs := []model.SystemConfig{
|
||||
{
|
||||
Key: "app.name",
|
||||
Value: "Bootstrap App",
|
||||
DefaultValue: "Bootstrap App",
|
||||
Description: "Application name",
|
||||
Category: "general",
|
||||
DataType: "string",
|
||||
IsEditable: true,
|
||||
DateModified: time.Now().UTC().Format(time.RFC3339),
|
||||
},
|
||||
{
|
||||
Key: "app.version",
|
||||
Value: "1.0.0",
|
||||
DefaultValue: "1.0.0",
|
||||
Description: "Application version",
|
||||
Category: "general",
|
||||
DataType: "string",
|
||||
IsEditable: false,
|
||||
DateModified: time.Now().UTC().Format(time.RFC3339),
|
||||
},
|
||||
{
|
||||
Key: "security.jwt_expiry_hours",
|
||||
Value: "24",
|
||||
DefaultValue: "24",
|
||||
Description: "JWT token expiry time in hours",
|
||||
Category: "security",
|
||||
DataType: "integer",
|
||||
IsEditable: true,
|
||||
DateModified: time.Now().UTC().Format(time.RFC3339),
|
||||
},
|
||||
{
|
||||
Key: "security.password_min_length",
|
||||
Value: "8",
|
||||
DefaultValue: "8",
|
||||
Description: "Minimum password length",
|
||||
Category: "security",
|
||||
DataType: "integer",
|
||||
IsEditable: true,
|
||||
DateModified: time.Now().UTC().Format(time.RFC3339),
|
||||
},
|
||||
{
|
||||
Key: "security.max_login_attempts",
|
||||
Value: "5",
|
||||
DefaultValue: "5",
|
||||
Description: "Maximum login attempts before lockout",
|
||||
Category: "security",
|
||||
DataType: "integer",
|
||||
IsEditable: true,
|
||||
DateModified: time.Now().UTC().Format(time.RFC3339),
|
||||
},
|
||||
{
|
||||
Key: "security.lockout_duration_minutes",
|
||||
Value: "30",
|
||||
DefaultValue: "30",
|
||||
Description: "Account lockout duration in minutes",
|
||||
Category: "security",
|
||||
DataType: "integer",
|
||||
IsEditable: true,
|
||||
DateModified: time.Now().UTC().Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
|
||||
for _, config := range configs {
|
||||
var existingConfig model.SystemConfig
|
||||
err := db.Where("key = ?", config.Key).First(&existingConfig).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
config.Init()
|
||||
if err := db.Create(&config).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
logging.Info("Created system config: %s", config.Key)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
184
local/utl/error_handler/controller_error_handler.go
Normal file
184
local/utl/error_handler/controller_error_handler.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package error_handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"omega-server/local/utl/logging"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// ControllerErrorHandler provides centralized error handling for controllers
|
||||
type ControllerErrorHandler struct {
|
||||
errorLogger *logging.ErrorLogger
|
||||
}
|
||||
|
||||
// NewControllerErrorHandler creates a new controller error handler instance
|
||||
func NewControllerErrorHandler() *ControllerErrorHandler {
|
||||
return &ControllerErrorHandler{
|
||||
errorLogger: logging.GetErrorLogger(),
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorResponse represents a standardized error response
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Code int `json:"code,omitempty"`
|
||||
Details map[string]string `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// HandleError handles controller errors with logging and standardized responses
|
||||
func (ceh *ControllerErrorHandler) HandleError(c *fiber.Ctx, err error, statusCode int, context ...string) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get caller information for logging
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
file = strings.TrimPrefix(file, "acc-server-manager/")
|
||||
|
||||
// Build context string
|
||||
contextStr := ""
|
||||
if len(context) > 0 {
|
||||
contextStr = fmt.Sprintf("[%s] ", strings.Join(context, "|"))
|
||||
}
|
||||
|
||||
// Clean error message (remove null bytes)
|
||||
cleanErrorMsg := strings.ReplaceAll(err.Error(), "\x00", "")
|
||||
|
||||
// Log the error with context
|
||||
ceh.errorLogger.LogWithContext(
|
||||
fmt.Sprintf("CONTROLLER_ERROR [%s:%d]", file, line),
|
||||
"%s%s",
|
||||
contextStr,
|
||||
cleanErrorMsg,
|
||||
)
|
||||
|
||||
// Create standardized error response
|
||||
errorResponse := ErrorResponse{
|
||||
Error: cleanErrorMsg,
|
||||
Code: statusCode,
|
||||
}
|
||||
|
||||
// Add request details if available
|
||||
if c != nil {
|
||||
if errorResponse.Details == nil {
|
||||
errorResponse.Details = make(map[string]string)
|
||||
}
|
||||
errorResponse.Details["method"] = c.Method()
|
||||
errorResponse.Details["path"] = c.Path()
|
||||
errorResponse.Details["ip"] = c.IP()
|
||||
}
|
||||
|
||||
// Return appropriate response based on status code
|
||||
if statusCode >= 500 {
|
||||
// For server errors, don't expose internal details
|
||||
return c.Status(statusCode).JSON(ErrorResponse{
|
||||
Error: "Internal server error",
|
||||
Code: statusCode,
|
||||
})
|
||||
}
|
||||
|
||||
return c.Status(statusCode).JSON(errorResponse)
|
||||
}
|
||||
|
||||
// HandleValidationError handles validation errors specifically
|
||||
func (ceh *ControllerErrorHandler) HandleValidationError(c *fiber.Ctx, err error, field string) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusBadRequest, "VALIDATION", field)
|
||||
}
|
||||
|
||||
// HandleDatabaseError handles database-related errors
|
||||
func (ceh *ControllerErrorHandler) HandleDatabaseError(c *fiber.Ctx, err error) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusInternalServerError, "DATABASE")
|
||||
}
|
||||
|
||||
// HandleAuthError handles authentication/authorization errors
|
||||
func (ceh *ControllerErrorHandler) HandleAuthError(c *fiber.Ctx, err error) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusUnauthorized, "AUTH")
|
||||
}
|
||||
|
||||
// HandleNotFoundError handles resource not found errors
|
||||
func (ceh *ControllerErrorHandler) HandleNotFoundError(c *fiber.Ctx, resource string) error {
|
||||
err := fmt.Errorf("%s not found", resource)
|
||||
return ceh.HandleError(c, err, fiber.StatusNotFound, "NOT_FOUND")
|
||||
}
|
||||
|
||||
// HandleBusinessLogicError handles business logic errors
|
||||
func (ceh *ControllerErrorHandler) HandleBusinessLogicError(c *fiber.Ctx, err error) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusBadRequest, "BUSINESS_LOGIC")
|
||||
}
|
||||
|
||||
// HandleServiceError handles service layer errors
|
||||
func (ceh *ControllerErrorHandler) HandleServiceError(c *fiber.Ctx, err error) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusInternalServerError, "SERVICE")
|
||||
}
|
||||
|
||||
// HandleParsingError handles request parsing errors
|
||||
func (ceh *ControllerErrorHandler) HandleParsingError(c *fiber.Ctx, err error) error {
|
||||
return ceh.HandleError(c, err, fiber.StatusBadRequest, "PARSING")
|
||||
}
|
||||
|
||||
// HandleUUIDError handles UUID parsing errors
|
||||
func (ceh *ControllerErrorHandler) HandleUUIDError(c *fiber.Ctx, field string) error {
|
||||
err := fmt.Errorf("invalid %s format", field)
|
||||
return ceh.HandleError(c, err, fiber.StatusBadRequest, "UUID_VALIDATION", field)
|
||||
}
|
||||
|
||||
// Global controller error handler instance
|
||||
var globalErrorHandler *ControllerErrorHandler
|
||||
|
||||
// GetControllerErrorHandler returns the global controller error handler instance
|
||||
func GetControllerErrorHandler() *ControllerErrorHandler {
|
||||
if globalErrorHandler == nil {
|
||||
globalErrorHandler = NewControllerErrorHandler()
|
||||
}
|
||||
return globalErrorHandler
|
||||
}
|
||||
|
||||
// Convenience functions using the global error handler
|
||||
|
||||
// HandleError handles controller errors using the global error handler
|
||||
func HandleError(c *fiber.Ctx, err error, statusCode int, context ...string) error {
|
||||
return GetControllerErrorHandler().HandleError(c, err, statusCode, context...)
|
||||
}
|
||||
|
||||
// HandleValidationError handles validation errors using the global error handler
|
||||
func HandleValidationError(c *fiber.Ctx, err error, field string) error {
|
||||
return GetControllerErrorHandler().HandleValidationError(c, err, field)
|
||||
}
|
||||
|
||||
// HandleDatabaseError handles database errors using the global error handler
|
||||
func HandleDatabaseError(c *fiber.Ctx, err error) error {
|
||||
return GetControllerErrorHandler().HandleDatabaseError(c, err)
|
||||
}
|
||||
|
||||
// HandleAuthError handles auth errors using the global error handler
|
||||
func HandleAuthError(c *fiber.Ctx, err error) error {
|
||||
return GetControllerErrorHandler().HandleAuthError(c, err)
|
||||
}
|
||||
|
||||
// HandleNotFoundError handles not found errors using the global error handler
|
||||
func HandleNotFoundError(c *fiber.Ctx, resource string) error {
|
||||
return GetControllerErrorHandler().HandleNotFoundError(c, resource)
|
||||
}
|
||||
|
||||
// HandleBusinessLogicError handles business logic errors using the global error handler
|
||||
func HandleBusinessLogicError(c *fiber.Ctx, err error) error {
|
||||
return GetControllerErrorHandler().HandleBusinessLogicError(c, err)
|
||||
}
|
||||
|
||||
// HandleServiceError handles service errors using the global error handler
|
||||
func HandleServiceError(c *fiber.Ctx, err error) error {
|
||||
return GetControllerErrorHandler().HandleServiceError(c, err)
|
||||
}
|
||||
|
||||
// HandleParsingError handles parsing errors using the global error handler
|
||||
func HandleParsingError(c *fiber.Ctx, err error) error {
|
||||
return GetControllerErrorHandler().HandleParsingError(c, err)
|
||||
}
|
||||
|
||||
// HandleUUIDError handles UUID errors using the global error handler
|
||||
func HandleUUIDError(c *fiber.Ctx, field string) error {
|
||||
return GetControllerErrorHandler().HandleUUIDError(c, field)
|
||||
}
|
||||
365
local/utl/jwt/jwt.go
Normal file
365
local/utl/jwt/jwt.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
// Claims represents the JWT claims
|
||||
type Claims struct {
|
||||
UserID string `json:"userId"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Roles []string `json:"roles"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// TokenPair represents access and refresh tokens
|
||||
type TokenPair struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpiresAt time.Time `json:"expiresAt"`
|
||||
TokenType string `json:"tokenType"`
|
||||
}
|
||||
|
||||
var (
|
||||
jwtSecret []byte
|
||||
accessTokenTTL time.Duration
|
||||
refreshTokenTTL time.Duration
|
||||
issuer string
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrExpiredToken = errors.New("token has expired")
|
||||
ErrInvalidClaims = errors.New("invalid token claims")
|
||||
)
|
||||
|
||||
// Initialize initializes the JWT package with configuration
|
||||
func Initialize() error {
|
||||
// Get JWT secret from environment
|
||||
secret := os.Getenv("JWT_SECRET")
|
||||
if secret == "" {
|
||||
return errors.New("JWT_SECRET environment variable is required")
|
||||
}
|
||||
jwtSecret = []byte(secret)
|
||||
|
||||
// Get token TTL from environment or use defaults
|
||||
accessTTLStr := os.Getenv("JWT_ACCESS_TTL_HOURS")
|
||||
if accessTTLStr == "" {
|
||||
accessTTLStr = "24" // 24 hours default
|
||||
}
|
||||
accessHours, err := strconv.Atoi(accessTTLStr)
|
||||
if err != nil {
|
||||
accessHours = 24
|
||||
}
|
||||
accessTokenTTL = time.Duration(accessHours) * time.Hour
|
||||
|
||||
refreshTTLStr := os.Getenv("JWT_REFRESH_TTL_DAYS")
|
||||
if refreshTTLStr == "" {
|
||||
refreshTTLStr = "7" // 7 days default
|
||||
}
|
||||
refreshDays, err := strconv.Atoi(refreshTTLStr)
|
||||
if err != nil {
|
||||
refreshDays = 7
|
||||
}
|
||||
refreshTokenTTL = time.Duration(refreshDays) * 24 * time.Hour
|
||||
|
||||
// Get issuer from environment
|
||||
issuer = os.Getenv("JWT_ISSUER")
|
||||
if issuer == "" {
|
||||
issuer = "omega-server"
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateTokenPair generates both access and refresh tokens
|
||||
func GenerateTokenPair(userID, email, username string, roles []string) (*TokenPair, error) {
|
||||
if len(jwtSecret) == 0 {
|
||||
if err := Initialize(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
accessExpiresAt := now.Add(accessTokenTTL)
|
||||
refreshExpiresAt := now.Add(refreshTokenTTL)
|
||||
|
||||
// Create access token claims
|
||||
accessClaims := &Claims{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
Username: username,
|
||||
Roles: roles,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(accessExpiresAt),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Issuer: issuer,
|
||||
Subject: userID,
|
||||
ID: generateJTI(),
|
||||
},
|
||||
}
|
||||
|
||||
// Create refresh token claims (minimal data)
|
||||
refreshClaims := &Claims{
|
||||
UserID: userID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(refreshExpiresAt),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Issuer: issuer,
|
||||
Subject: userID,
|
||||
ID: generateJTI(),
|
||||
},
|
||||
}
|
||||
|
||||
// Generate access token
|
||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
|
||||
accessTokenString, err := accessToken.SignedString(jwtSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate refresh token
|
||||
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
|
||||
refreshTokenString, err := refreshToken.SignedString(jwtSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TokenPair{
|
||||
AccessToken: accessTokenString,
|
||||
RefreshToken: refreshTokenString,
|
||||
ExpiresAt: accessExpiresAt,
|
||||
TokenType: "Bearer",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateToken validates a JWT token and returns the claims
|
||||
func ValidateToken(tokenString string) (*Claims, error) {
|
||||
if len(jwtSecret) == 0 {
|
||||
if err := Initialize(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// Validate signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("invalid signing method")
|
||||
}
|
||||
return jwtSecret, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, ErrExpiredToken
|
||||
}
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, ErrInvalidClaims
|
||||
}
|
||||
|
||||
// Additional validation
|
||||
if claims.UserID == "" {
|
||||
return nil, ErrInvalidClaims
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// RefreshToken generates a new access token from a valid refresh token
|
||||
func RefreshToken(refreshTokenString string) (*TokenPair, error) {
|
||||
// Validate refresh token
|
||||
claims, err := ValidateToken(refreshTokenString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// For refresh tokens, we only have minimal user data
|
||||
// In a real application, you might want to fetch fresh user data from the database
|
||||
return GenerateTokenPair(claims.UserID, claims.Email, claims.Username, claims.Roles)
|
||||
}
|
||||
|
||||
// ExtractTokenFromHeader extracts JWT token from Authorization header
|
||||
func ExtractTokenFromHeader(authHeader string) (string, error) {
|
||||
if authHeader == "" {
|
||||
return "", errors.New("authorization header is required")
|
||||
}
|
||||
|
||||
const bearerPrefix = "Bearer "
|
||||
if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix {
|
||||
return "", errors.New("invalid authorization header format")
|
||||
}
|
||||
|
||||
token := authHeader[len(bearerPrefix):]
|
||||
if token == "" {
|
||||
return "", errors.New("token is required")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// IsTokenExpired checks if a token is expired without validating signature
|
||||
func IsTokenExpired(tokenString string) bool {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return jwtSecret, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
return claims.ExpiresAt.Before(time.Now())
|
||||
}
|
||||
|
||||
// GetTokenClaims extracts claims from token without validating signature (use carefully)
|
||||
func GetTokenClaims(tokenString string) (*Claims, error) {
|
||||
token, _, err := new(jwt.Parser).ParseUnverified(tokenString, &Claims{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok {
|
||||
return nil, ErrInvalidClaims
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// RevokeToken adds token to revocation list (implement with your chosen storage)
|
||||
func RevokeToken(tokenString string) error {
|
||||
// In a real implementation, you would store revoked tokens in a database or cache
|
||||
// with their JTI (JWT ID) and expiration time for efficient cleanup
|
||||
|
||||
claims, err := GetTokenClaims(tokenString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store the JTI in your revocation storage
|
||||
// Example: revokedTokens[claims.ID] = claims.ExpiresAt
|
||||
|
||||
_ = claims // Placeholder to avoid unused variable error
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsTokenRevoked checks if a token has been revoked (implement with your chosen storage)
|
||||
func IsTokenRevoked(tokenString string) bool {
|
||||
// In a real implementation, check if the token's JTI is in the revocation list
|
||||
|
||||
claims, err := GetTokenClaims(tokenString)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check revocation storage
|
||||
// Example: _, revoked := revokedTokens[claims.ID]
|
||||
// return revoked
|
||||
|
||||
_ = claims // Placeholder to avoid unused variable error
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// generateJTI generates a unique JWT ID
|
||||
func generateJTI() string {
|
||||
// In a real implementation, use a proper UUID generator
|
||||
return strconv.FormatInt(time.Now().UnixNano(), 36)
|
||||
}
|
||||
|
||||
// GetTokenExpirationTime returns the expiration time of a token
|
||||
func GetTokenExpirationTime(tokenString string) (time.Time, error) {
|
||||
claims, err := GetTokenClaims(tokenString)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
if claims.ExpiresAt == nil {
|
||||
return time.Time{}, errors.New("token has no expiration time")
|
||||
}
|
||||
|
||||
return claims.ExpiresAt.Time, nil
|
||||
}
|
||||
|
||||
// GetTokenRemainingTime returns the remaining time before token expires
|
||||
func GetTokenRemainingTime(tokenString string) (time.Duration, error) {
|
||||
expirationTime, err := GetTokenExpirationTime(tokenString)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
remaining := time.Until(expirationTime)
|
||||
if remaining < 0 {
|
||||
return 0, ErrExpiredToken
|
||||
}
|
||||
|
||||
return remaining, nil
|
||||
}
|
||||
|
||||
// HasRole checks if the token claims contain a specific role
|
||||
func (c *Claims) HasRole(role string) bool {
|
||||
for _, r := range c.Roles {
|
||||
if r == role {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasAnyRole checks if the token claims contain any of the specified roles
|
||||
func (c *Claims) HasAnyRole(roles []string) bool {
|
||||
for _, role := range roles {
|
||||
if c.HasRole(role) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasAllRoles checks if the token claims contain all of the specified roles
|
||||
func (c *Claims) HasAllRoles(roles []string) bool {
|
||||
for _, role := range roles {
|
||||
if !c.HasRole(role) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsAdmin checks if the user has admin role
|
||||
func (c *Claims) IsAdmin() bool {
|
||||
return c.HasRole("admin")
|
||||
}
|
||||
|
||||
// GetUserInfo returns basic user information from claims
|
||||
func (c *Claims) GetUserInfo() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"userId": c.UserID,
|
||||
"email": c.Email,
|
||||
"username": c.Username,
|
||||
"roles": c.Roles,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateToken generates a JWT token for a user (backward compatibility)
|
||||
func GenerateToken(userID, email, username string, roles []string) (string, error) {
|
||||
tokenPair, err := GenerateTokenPair(userID, email, username, roles)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return tokenPair.AccessToken, nil
|
||||
}
|
||||
167
local/utl/logging/base.go
Normal file
167
local/utl/logging/base.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
timeFormat = "2006-01-02 15:04:05.000"
|
||||
)
|
||||
|
||||
// BaseLogger provides the core logging functionality
|
||||
type BaseLogger struct {
|
||||
file *os.File
|
||||
logger *log.Logger
|
||||
mu sync.RWMutex
|
||||
initialized bool
|
||||
}
|
||||
|
||||
// LogLevel represents different logging levels
|
||||
type LogLevel string
|
||||
|
||||
const (
|
||||
LogLevelError LogLevel = "ERROR"
|
||||
LogLevelWarn LogLevel = "WARN"
|
||||
LogLevelInfo LogLevel = "INFO"
|
||||
LogLevelDebug LogLevel = "DEBUG"
|
||||
LogLevelPanic LogLevel = "PANIC"
|
||||
)
|
||||
|
||||
// Initialize creates a new base logger instance
|
||||
func InitializeBase(tp string) (*BaseLogger, error) {
|
||||
return newBaseLogger(tp)
|
||||
}
|
||||
|
||||
func newBaseLogger(tp string) (*BaseLogger, error) {
|
||||
// Ensure logs directory exists
|
||||
if err := os.MkdirAll("logs", 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create logs directory: %v", err)
|
||||
}
|
||||
|
||||
// Open log file with date in name
|
||||
logPath := filepath.Join("logs", fmt.Sprintf("acc-server-%s-%s.log", time.Now().Format("2006-01-02"), tp))
|
||||
file, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0666)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open log file: %v", err)
|
||||
}
|
||||
|
||||
// Create multi-writer for both file and console
|
||||
multiWriter := io.MultiWriter(file, os.Stdout)
|
||||
|
||||
// Create base logger
|
||||
logger := &BaseLogger{
|
||||
file: file,
|
||||
logger: log.New(multiWriter, "", 0),
|
||||
initialized: true,
|
||||
}
|
||||
|
||||
return logger, nil
|
||||
}
|
||||
|
||||
// GetBaseLogger creates and returns a new base logger instance
|
||||
func GetBaseLogger(tp string) *BaseLogger {
|
||||
baseLogger, _ := InitializeBase(tp)
|
||||
return baseLogger
|
||||
}
|
||||
|
||||
// Close closes the log file
|
||||
func (bl *BaseLogger) Close() error {
|
||||
bl.mu.Lock()
|
||||
defer bl.mu.Unlock()
|
||||
|
||||
if bl.file != nil {
|
||||
return bl.file.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Log writes a log entry with the specified level
|
||||
func (bl *BaseLogger) Log(level LogLevel, format string, v ...interface{}) {
|
||||
if bl == nil || !bl.initialized {
|
||||
return
|
||||
}
|
||||
|
||||
bl.mu.RLock()
|
||||
defer bl.mu.RUnlock()
|
||||
|
||||
// Get caller info (skip 2 frames: this function and the calling Log function)
|
||||
_, file, line, _ := runtime.Caller(2)
|
||||
file = filepath.Base(file)
|
||||
|
||||
// Format message
|
||||
msg := fmt.Sprintf(format, v...)
|
||||
|
||||
// Format final log line
|
||||
logLine := fmt.Sprintf("[%s] [%s] [%s:%d] %s",
|
||||
time.Now().Format(timeFormat),
|
||||
string(level),
|
||||
file,
|
||||
line,
|
||||
msg,
|
||||
)
|
||||
|
||||
bl.logger.Println(logLine)
|
||||
}
|
||||
|
||||
// LogWithCaller writes a log entry with custom caller depth
|
||||
func (bl *BaseLogger) LogWithCaller(level LogLevel, callerDepth int, format string, v ...interface{}) {
|
||||
if bl == nil || !bl.initialized {
|
||||
return
|
||||
}
|
||||
|
||||
bl.mu.RLock()
|
||||
defer bl.mu.RUnlock()
|
||||
|
||||
// Get caller info with custom depth
|
||||
_, file, line, _ := runtime.Caller(callerDepth)
|
||||
file = filepath.Base(file)
|
||||
|
||||
// Format message
|
||||
msg := fmt.Sprintf(format, v...)
|
||||
|
||||
// Format final log line
|
||||
logLine := fmt.Sprintf("[%s] [%s] [%s:%d] %s",
|
||||
time.Now().Format(timeFormat),
|
||||
string(level),
|
||||
file,
|
||||
line,
|
||||
msg,
|
||||
)
|
||||
|
||||
bl.logger.Println(logLine)
|
||||
}
|
||||
|
||||
// IsInitialized returns whether the base logger is initialized
|
||||
func (bl *BaseLogger) IsInitialized() bool {
|
||||
if bl == nil {
|
||||
return false
|
||||
}
|
||||
bl.mu.RLock()
|
||||
defer bl.mu.RUnlock()
|
||||
return bl.initialized
|
||||
}
|
||||
|
||||
// RecoverAndLog recovers from panics and logs them
|
||||
func RecoverAndLog() {
|
||||
baseLogger := GetBaseLogger("panic")
|
||||
if baseLogger != nil && baseLogger.IsInitialized() {
|
||||
if r := recover(); r != nil {
|
||||
// Get stack trace
|
||||
buf := make([]byte, 4096)
|
||||
n := runtime.Stack(buf, false)
|
||||
stackTrace := string(buf[:n])
|
||||
|
||||
baseLogger.LogWithCaller(LogLevelPanic, 2, "Recovered from panic: %v\nStack Trace:\n%s", r, stackTrace)
|
||||
|
||||
// Re-panic to maintain original behavior if needed
|
||||
panic(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
159
local/utl/logging/debug.go
Normal file
159
local/utl/logging/debug.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// DebugLogger handles debug-level logging
|
||||
type DebugLogger struct {
|
||||
base *BaseLogger
|
||||
}
|
||||
|
||||
// NewDebugLogger creates a new debug logger instance
|
||||
func NewDebugLogger() *DebugLogger {
|
||||
base, _ := InitializeBase("debug")
|
||||
return &DebugLogger{
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
// Log writes a debug-level log entry
|
||||
func (dl *DebugLogger) Log(format string, v ...interface{}) {
|
||||
if dl.base != nil {
|
||||
dl.base.Log(LogLevelDebug, format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithContext writes a debug-level log entry with additional context
|
||||
func (dl *DebugLogger) LogWithContext(context string, format string, v ...interface{}) {
|
||||
if dl.base != nil {
|
||||
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
|
||||
dl.base.Log(LogLevelDebug, contextualFormat, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogFunction logs function entry and exit for debugging
|
||||
func (dl *DebugLogger) LogFunction(functionName string, args ...interface{}) {
|
||||
if dl.base != nil {
|
||||
if len(args) > 0 {
|
||||
dl.base.Log(LogLevelDebug, "FUNCTION [%s] called with args: %+v", functionName, args)
|
||||
} else {
|
||||
dl.base.Log(LogLevelDebug, "FUNCTION [%s] called", functionName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LogVariable logs variable values for debugging
|
||||
func (dl *DebugLogger) LogVariable(varName string, value interface{}) {
|
||||
if dl.base != nil {
|
||||
dl.base.Log(LogLevelDebug, "VARIABLE [%s]: %+v", varName, value)
|
||||
}
|
||||
}
|
||||
|
||||
// LogState logs application state information
|
||||
func (dl *DebugLogger) LogState(component string, state interface{}) {
|
||||
if dl.base != nil {
|
||||
dl.base.Log(LogLevelDebug, "STATE [%s]: %+v", component, state)
|
||||
}
|
||||
}
|
||||
|
||||
// LogSQL logs SQL queries for debugging
|
||||
func (dl *DebugLogger) LogSQL(query string, args ...interface{}) {
|
||||
if dl.base != nil {
|
||||
if len(args) > 0 {
|
||||
dl.base.Log(LogLevelDebug, "SQL: %s | Args: %+v", query, args)
|
||||
} else {
|
||||
dl.base.Log(LogLevelDebug, "SQL: %s", query)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LogMemory logs memory usage information
|
||||
func (dl *DebugLogger) LogMemory() {
|
||||
if dl.base != nil {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
dl.base.Log(LogLevelDebug, "MEMORY: Alloc = %d KB, TotalAlloc = %d KB, Sys = %d KB, NumGC = %d",
|
||||
bToKb(m.Alloc), bToKb(m.TotalAlloc), bToKb(m.Sys), m.NumGC)
|
||||
}
|
||||
}
|
||||
|
||||
// LogGoroutines logs current number of goroutines
|
||||
func (dl *DebugLogger) LogGoroutines() {
|
||||
if dl.base != nil {
|
||||
dl.base.Log(LogLevelDebug, "GOROUTINES: %d active", runtime.NumGoroutine())
|
||||
}
|
||||
}
|
||||
|
||||
// LogTiming logs timing information for performance debugging
|
||||
func (dl *DebugLogger) LogTiming(operation string, duration interface{}) {
|
||||
if dl.base != nil {
|
||||
dl.base.Log(LogLevelDebug, "TIMING [%s]: %v", operation, duration)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to convert bytes to kilobytes
|
||||
func bToKb(b uint64) uint64 {
|
||||
return b / 1024
|
||||
}
|
||||
|
||||
// Global debug logger instance
|
||||
var (
|
||||
debugLogger *DebugLogger
|
||||
debugOnce sync.Once
|
||||
)
|
||||
|
||||
// GetDebugLogger returns the global debug logger instance
|
||||
func GetDebugLogger() *DebugLogger {
|
||||
debugOnce.Do(func() {
|
||||
debugLogger = NewDebugLogger()
|
||||
})
|
||||
return debugLogger
|
||||
}
|
||||
|
||||
// Debug logs a debug-level message using the global debug logger
|
||||
func Debug(format string, v ...interface{}) {
|
||||
GetDebugLogger().Log(format, v...)
|
||||
}
|
||||
|
||||
// DebugWithContext logs a debug-level message with context using the global debug logger
|
||||
func DebugWithContext(context string, format string, v ...interface{}) {
|
||||
GetDebugLogger().LogWithContext(context, format, v...)
|
||||
}
|
||||
|
||||
// DebugFunction logs function entry and exit using the global debug logger
|
||||
func DebugFunction(functionName string, args ...interface{}) {
|
||||
GetDebugLogger().LogFunction(functionName, args...)
|
||||
}
|
||||
|
||||
// DebugVariable logs variable values using the global debug logger
|
||||
func DebugVariable(varName string, value interface{}) {
|
||||
GetDebugLogger().LogVariable(varName, value)
|
||||
}
|
||||
|
||||
// DebugState logs application state information using the global debug logger
|
||||
func DebugState(component string, state interface{}) {
|
||||
GetDebugLogger().LogState(component, state)
|
||||
}
|
||||
|
||||
// DebugSQL logs SQL queries using the global debug logger
|
||||
func DebugSQL(query string, args ...interface{}) {
|
||||
GetDebugLogger().LogSQL(query, args...)
|
||||
}
|
||||
|
||||
// DebugMemory logs memory usage information using the global debug logger
|
||||
func DebugMemory() {
|
||||
GetDebugLogger().LogMemory()
|
||||
}
|
||||
|
||||
// DebugGoroutines logs current number of goroutines using the global debug logger
|
||||
func DebugGoroutines() {
|
||||
GetDebugLogger().LogGoroutines()
|
||||
}
|
||||
|
||||
// DebugTiming logs timing information using the global debug logger
|
||||
func DebugTiming(operation string, duration interface{}) {
|
||||
GetDebugLogger().LogTiming(operation, duration)
|
||||
}
|
||||
106
local/utl/logging/error.go
Normal file
106
local/utl/logging/error.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ErrorLogger handles error-level logging
|
||||
type ErrorLogger struct {
|
||||
base *BaseLogger
|
||||
}
|
||||
|
||||
// NewErrorLogger creates a new error logger instance
|
||||
func NewErrorLogger() *ErrorLogger {
|
||||
base, _ := InitializeBase("error")
|
||||
return &ErrorLogger{
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
// Log writes an error-level log entry
|
||||
func (el *ErrorLogger) Log(format string, v ...interface{}) {
|
||||
if el.base != nil {
|
||||
el.base.Log(LogLevelError, format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithContext writes an error-level log entry with additional context
|
||||
func (el *ErrorLogger) LogWithContext(context string, format string, v ...interface{}) {
|
||||
if el.base != nil {
|
||||
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
|
||||
el.base.Log(LogLevelError, contextualFormat, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogError logs an error object with optional message
|
||||
func (el *ErrorLogger) LogError(err error, message ...string) {
|
||||
if el.base != nil && err != nil {
|
||||
if len(message) > 0 {
|
||||
el.base.Log(LogLevelError, "%s: %v", message[0], err)
|
||||
} else {
|
||||
el.base.Log(LogLevelError, "Error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithStackTrace logs an error with stack trace
|
||||
func (el *ErrorLogger) LogWithStackTrace(format string, v ...interface{}) {
|
||||
if el.base != nil {
|
||||
// Get stack trace
|
||||
buf := make([]byte, 4096)
|
||||
n := runtime.Stack(buf, false)
|
||||
stackTrace := string(buf[:n])
|
||||
|
||||
msg := fmt.Sprintf(format, v...)
|
||||
el.base.Log(LogLevelError, "%s\nStack Trace:\n%s", msg, stackTrace)
|
||||
}
|
||||
}
|
||||
|
||||
// LogFatal logs a fatal error and exits the program
|
||||
func (el *ErrorLogger) LogFatal(format string, v ...interface{}) {
|
||||
if el.base != nil {
|
||||
el.base.Log(LogLevelError, "[FATAL] "+format, v...)
|
||||
panic(fmt.Sprintf(format, v...))
|
||||
}
|
||||
}
|
||||
|
||||
// Global error logger instance
|
||||
var (
|
||||
errorLogger *ErrorLogger
|
||||
errorOnce sync.Once
|
||||
)
|
||||
|
||||
// GetErrorLogger returns the global error logger instance
|
||||
func GetErrorLogger() *ErrorLogger {
|
||||
errorOnce.Do(func() {
|
||||
errorLogger = NewErrorLogger()
|
||||
})
|
||||
return errorLogger
|
||||
}
|
||||
|
||||
// Error logs an error-level message using the global error logger
|
||||
func Error(format string, v ...interface{}) {
|
||||
GetErrorLogger().Log(format, v...)
|
||||
}
|
||||
|
||||
// ErrorWithContext logs an error-level message with context using the global error logger
|
||||
func ErrorWithContext(context string, format string, v ...interface{}) {
|
||||
GetErrorLogger().LogWithContext(context, format, v...)
|
||||
}
|
||||
|
||||
// LogError logs an error object using the global error logger
|
||||
func LogError(err error, message ...string) {
|
||||
GetErrorLogger().LogError(err, message...)
|
||||
}
|
||||
|
||||
// ErrorWithStackTrace logs an error with stack trace using the global error logger
|
||||
func ErrorWithStackTrace(format string, v ...interface{}) {
|
||||
GetErrorLogger().LogWithStackTrace(format, v...)
|
||||
}
|
||||
|
||||
// Fatal logs a fatal error and exits the program using the global error logger
|
||||
func Fatal(format string, v ...interface{}) {
|
||||
GetErrorLogger().LogFatal(format, v...)
|
||||
}
|
||||
130
local/utl/logging/info.go
Normal file
130
local/utl/logging/info.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// InfoLogger handles info-level logging
|
||||
type InfoLogger struct {
|
||||
base *BaseLogger
|
||||
}
|
||||
|
||||
// NewInfoLogger creates a new info logger instance
|
||||
func NewInfoLogger() *InfoLogger {
|
||||
base, _ := InitializeBase("info")
|
||||
return &InfoLogger{
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
// Log writes an info-level log entry
|
||||
func (il *InfoLogger) Log(format string, v ...interface{}) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithContext writes an info-level log entry with additional context
|
||||
func (il *InfoLogger) LogWithContext(context string, format string, v ...interface{}) {
|
||||
if il.base != nil {
|
||||
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
|
||||
il.base.Log(LogLevelInfo, contextualFormat, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogStartup logs application startup information
|
||||
func (il *InfoLogger) LogStartup(component string, message string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "STARTUP [%s]: %s", component, message)
|
||||
}
|
||||
}
|
||||
|
||||
// LogShutdown logs application shutdown information
|
||||
func (il *InfoLogger) LogShutdown(component string, message string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "SHUTDOWN [%s]: %s", component, message)
|
||||
}
|
||||
}
|
||||
|
||||
// LogOperation logs general operation information
|
||||
func (il *InfoLogger) LogOperation(operation string, details string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "OPERATION [%s]: %s", operation, details)
|
||||
}
|
||||
}
|
||||
|
||||
// LogStatus logs status changes or updates
|
||||
func (il *InfoLogger) LogStatus(component string, status string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "STATUS [%s]: %s", component, status)
|
||||
}
|
||||
}
|
||||
|
||||
// LogRequest logs incoming requests
|
||||
func (il *InfoLogger) LogRequest(method string, path string, userAgent string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "REQUEST [%s %s] User-Agent: %s", method, path, userAgent)
|
||||
}
|
||||
}
|
||||
|
||||
// LogResponse logs outgoing responses
|
||||
func (il *InfoLogger) LogResponse(method string, path string, statusCode int, duration string) {
|
||||
if il.base != nil {
|
||||
il.base.Log(LogLevelInfo, "RESPONSE [%s %s] Status: %d, Duration: %s", method, path, statusCode, duration)
|
||||
}
|
||||
}
|
||||
|
||||
// Global info logger instance
|
||||
var (
|
||||
infoLogger *InfoLogger
|
||||
infoOnce sync.Once
|
||||
)
|
||||
|
||||
// GetInfoLogger returns the global info logger instance
|
||||
func GetInfoLogger() *InfoLogger {
|
||||
infoOnce.Do(func() {
|
||||
infoLogger = NewInfoLogger()
|
||||
})
|
||||
return infoLogger
|
||||
}
|
||||
|
||||
// Info logs an info-level message using the global info logger
|
||||
func Info(format string, v ...interface{}) {
|
||||
GetInfoLogger().Log(format, v...)
|
||||
}
|
||||
|
||||
// InfoWithContext logs an info-level message with context using the global info logger
|
||||
func InfoWithContext(context string, format string, v ...interface{}) {
|
||||
GetInfoLogger().LogWithContext(context, format, v...)
|
||||
}
|
||||
|
||||
// InfoStartup logs application startup information using the global info logger
|
||||
func InfoStartup(component string, message string) {
|
||||
GetInfoLogger().LogStartup(component, message)
|
||||
}
|
||||
|
||||
// InfoShutdown logs application shutdown information using the global info logger
|
||||
func InfoShutdown(component string, message string) {
|
||||
GetInfoLogger().LogShutdown(component, message)
|
||||
}
|
||||
|
||||
// InfoOperation logs general operation information using the global info logger
|
||||
func InfoOperation(operation string, details string) {
|
||||
GetInfoLogger().LogOperation(operation, details)
|
||||
}
|
||||
|
||||
// InfoStatus logs status changes or updates using the global info logger
|
||||
func InfoStatus(component string, status string) {
|
||||
GetInfoLogger().LogStatus(component, status)
|
||||
}
|
||||
|
||||
// InfoRequest logs incoming requests using the global info logger
|
||||
func InfoRequest(method string, path string, userAgent string) {
|
||||
GetInfoLogger().LogRequest(method, path, userAgent)
|
||||
}
|
||||
|
||||
// InfoResponse logs outgoing responses using the global info logger
|
||||
func InfoResponse(method string, path string, statusCode int, duration string) {
|
||||
GetInfoLogger().LogResponse(method, path, statusCode, duration)
|
||||
}
|
||||
213
local/utl/logging/logger.go
Normal file
213
local/utl/logging/logger.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
// Legacy logger for backward compatibility
|
||||
logger *Logger
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
// Logger maintains backward compatibility with existing code
|
||||
type Logger struct {
|
||||
base *BaseLogger
|
||||
errorLogger *ErrorLogger
|
||||
warnLogger *WarnLogger
|
||||
infoLogger *InfoLogger
|
||||
debugLogger *DebugLogger
|
||||
}
|
||||
|
||||
// Initialize creates or gets the singleton logger instance
|
||||
// This maintains backward compatibility with existing code
|
||||
func Initialize() (*Logger, error) {
|
||||
var err error
|
||||
once.Do(func() {
|
||||
logger, err = newLogger()
|
||||
})
|
||||
return logger, err
|
||||
}
|
||||
|
||||
func newLogger() (*Logger, error) {
|
||||
// Initialize the base logger
|
||||
baseLogger, err := InitializeBase("log")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create the legacy logger wrapper
|
||||
logger := &Logger{
|
||||
base: baseLogger,
|
||||
errorLogger: GetErrorLogger(),
|
||||
warnLogger: GetWarnLogger(),
|
||||
infoLogger: GetInfoLogger(),
|
||||
debugLogger: GetDebugLogger(),
|
||||
}
|
||||
|
||||
return logger, nil
|
||||
}
|
||||
|
||||
// Close closes the logger
|
||||
func (l *Logger) Close() error {
|
||||
if l.base != nil {
|
||||
return l.base.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Legacy methods for backward compatibility
|
||||
func (l *Logger) log(level, format string, v ...interface{}) {
|
||||
if l.base != nil {
|
||||
l.base.LogWithCaller(LogLevel(level), 3, format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Info(format string, v ...interface{}) {
|
||||
if l.infoLogger != nil {
|
||||
l.infoLogger.Log(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Error(format string, v ...interface{}) {
|
||||
if l.errorLogger != nil {
|
||||
l.errorLogger.Log(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn(format string, v ...interface{}) {
|
||||
if l.warnLogger != nil {
|
||||
l.warnLogger.Log(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Debug(format string, v ...interface{}) {
|
||||
if l.debugLogger != nil {
|
||||
l.debugLogger.Log(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Panic(format string) {
|
||||
if l.errorLogger != nil {
|
||||
l.errorLogger.LogFatal(format)
|
||||
}
|
||||
}
|
||||
|
||||
// Global convenience functions for backward compatibility
|
||||
// These are now implemented in individual logger files to avoid redeclaration
|
||||
func LegacyInfo(format string, v ...interface{}) {
|
||||
if logger != nil {
|
||||
logger.Info(format, v...)
|
||||
} else {
|
||||
// Fallback to direct logger if legacy logger not initialized
|
||||
GetInfoLogger().Log(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
func LegacyError(format string, v ...interface{}) {
|
||||
if logger != nil {
|
||||
logger.Error(format, v...)
|
||||
} else {
|
||||
// Fallback to direct logger if legacy logger not initialized
|
||||
GetErrorLogger().Log(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
func LegacyWarn(format string, v ...interface{}) {
|
||||
if logger != nil {
|
||||
logger.Warn(format, v...)
|
||||
} else {
|
||||
// Fallback to direct logger if legacy logger not initialized
|
||||
GetWarnLogger().Log(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
func LegacyDebug(format string, v ...interface{}) {
|
||||
if logger != nil {
|
||||
logger.Debug(format, v...)
|
||||
} else {
|
||||
// Fallback to direct logger if legacy logger not initialized
|
||||
GetDebugLogger().Log(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
func Panic(format string) {
|
||||
if logger != nil {
|
||||
logger.Panic(format)
|
||||
} else {
|
||||
// Fallback to direct logger if legacy logger not initialized
|
||||
GetErrorLogger().LogFatal(format)
|
||||
}
|
||||
}
|
||||
|
||||
// Enhanced logging convenience functions
|
||||
// These provide direct access to specialized logging functions
|
||||
|
||||
// LogStartup logs application startup information
|
||||
func LogStartup(component string, message string) {
|
||||
GetInfoLogger().LogStartup(component, message)
|
||||
}
|
||||
|
||||
// LogShutdown logs application shutdown information
|
||||
func LogShutdown(component string, message string) {
|
||||
GetInfoLogger().LogShutdown(component, message)
|
||||
}
|
||||
|
||||
// LogOperation logs general operation information
|
||||
func LogOperation(operation string, details string) {
|
||||
GetInfoLogger().LogOperation(operation, details)
|
||||
}
|
||||
|
||||
// LogRequest logs incoming HTTP requests
|
||||
func LogRequest(method string, path string, userAgent string) {
|
||||
GetInfoLogger().LogRequest(method, path, userAgent)
|
||||
}
|
||||
|
||||
// LogResponse logs outgoing HTTP responses
|
||||
func LogResponse(method string, path string, statusCode int, duration string) {
|
||||
GetInfoLogger().LogResponse(method, path, statusCode, duration)
|
||||
}
|
||||
|
||||
// LogSQL logs SQL queries for debugging
|
||||
func LogSQL(query string, args ...interface{}) {
|
||||
GetDebugLogger().LogSQL(query, args...)
|
||||
}
|
||||
|
||||
// LogMemory logs memory usage information
|
||||
func LogMemory() {
|
||||
GetDebugLogger().LogMemory()
|
||||
}
|
||||
|
||||
// LogTiming logs timing information for performance debugging
|
||||
func LogTiming(operation string, duration interface{}) {
|
||||
GetDebugLogger().LogTiming(operation, duration)
|
||||
}
|
||||
|
||||
// GetLegacyLogger returns the legacy logger instance for backward compatibility
|
||||
func GetLegacyLogger() *Logger {
|
||||
if logger == nil {
|
||||
logger, _ = Initialize()
|
||||
}
|
||||
return logger
|
||||
}
|
||||
|
||||
// InitializeLogging initializes all logging components
|
||||
func InitializeLogging() error {
|
||||
// Initialize legacy logger for backward compatibility
|
||||
_, err := Initialize()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize legacy logger: %v", err)
|
||||
}
|
||||
|
||||
// Pre-initialize all logger types to ensure separate log files
|
||||
GetErrorLogger()
|
||||
GetWarnLogger()
|
||||
GetInfoLogger()
|
||||
GetDebugLogger()
|
||||
|
||||
// Log successful initialization
|
||||
Info("Logging system initialized successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
98
local/utl/logging/warn.go
Normal file
98
local/utl/logging/warn.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// WarnLogger handles warn-level logging
|
||||
type WarnLogger struct {
|
||||
base *BaseLogger
|
||||
}
|
||||
|
||||
// NewWarnLogger creates a new warn logger instance
|
||||
func NewWarnLogger() *WarnLogger {
|
||||
base, _ := InitializeBase("warn")
|
||||
return &WarnLogger{
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
// Log writes a warn-level log entry
|
||||
func (wl *WarnLogger) Log(format string, v ...interface{}) {
|
||||
if wl.base != nil {
|
||||
wl.base.Log(LogLevelWarn, format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogWithContext writes a warn-level log entry with additional context
|
||||
func (wl *WarnLogger) LogWithContext(context string, format string, v ...interface{}) {
|
||||
if wl.base != nil {
|
||||
contextualFormat := fmt.Sprintf("[%s] %s", context, format)
|
||||
wl.base.Log(LogLevelWarn, contextualFormat, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogDeprecation logs a deprecation warning
|
||||
func (wl *WarnLogger) LogDeprecation(feature string, alternative string) {
|
||||
if wl.base != nil {
|
||||
if alternative != "" {
|
||||
wl.base.Log(LogLevelWarn, "DEPRECATED: %s is deprecated, use %s instead", feature, alternative)
|
||||
} else {
|
||||
wl.base.Log(LogLevelWarn, "DEPRECATED: %s is deprecated", feature)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LogConfiguration logs configuration-related warnings
|
||||
func (wl *WarnLogger) LogConfiguration(setting string, message string) {
|
||||
if wl.base != nil {
|
||||
wl.base.Log(LogLevelWarn, "CONFIG WARNING [%s]: %s", setting, message)
|
||||
}
|
||||
}
|
||||
|
||||
// LogPerformance logs performance-related warnings
|
||||
func (wl *WarnLogger) LogPerformance(operation string, threshold string, actual string) {
|
||||
if wl.base != nil {
|
||||
wl.base.Log(LogLevelWarn, "PERFORMANCE WARNING [%s]: exceeded threshold %s, actual: %s", operation, threshold, actual)
|
||||
}
|
||||
}
|
||||
|
||||
// Global warn logger instance
|
||||
var (
|
||||
warnLogger *WarnLogger
|
||||
warnOnce sync.Once
|
||||
)
|
||||
|
||||
// GetWarnLogger returns the global warn logger instance
|
||||
func GetWarnLogger() *WarnLogger {
|
||||
warnOnce.Do(func() {
|
||||
warnLogger = NewWarnLogger()
|
||||
})
|
||||
return warnLogger
|
||||
}
|
||||
|
||||
// Warn logs a warn-level message using the global warn logger
|
||||
func Warn(format string, v ...interface{}) {
|
||||
GetWarnLogger().Log(format, v...)
|
||||
}
|
||||
|
||||
// WarnWithContext logs a warn-level message with context using the global warn logger
|
||||
func WarnWithContext(context string, format string, v ...interface{}) {
|
||||
GetWarnLogger().LogWithContext(context, format, v...)
|
||||
}
|
||||
|
||||
// WarnDeprecation logs a deprecation warning using the global warn logger
|
||||
func WarnDeprecation(feature string, alternative string) {
|
||||
GetWarnLogger().LogDeprecation(feature, alternative)
|
||||
}
|
||||
|
||||
// WarnConfiguration logs configuration-related warnings using the global warn logger
|
||||
func WarnConfiguration(setting string, message string) {
|
||||
GetWarnLogger().LogConfiguration(setting, message)
|
||||
}
|
||||
|
||||
// WarnPerformance logs performance-related warnings using the global warn logger
|
||||
func WarnPerformance(operation string, threshold string, actual string) {
|
||||
GetWarnLogger().LogPerformance(operation, threshold, actual)
|
||||
}
|
||||
336
local/utl/password/password.go
Normal file
336
local/utl/password/password.go
Normal file
@@ -0,0 +1,336 @@
|
||||
package password
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
"unicode"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
// MinPasswordLength minimum password length
|
||||
MinPasswordLength = 8
|
||||
// MaxPasswordLength maximum password length
|
||||
MaxPasswordLength = 128
|
||||
// DefaultCost default bcrypt cost
|
||||
DefaultCost = 12
|
||||
)
|
||||
|
||||
// HashPassword hashes a plain text password using bcrypt
|
||||
func HashPassword(password string) (string, error) {
|
||||
if err := ValidatePasswordStrength(password); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(bytes), nil
|
||||
}
|
||||
|
||||
// CheckPasswordHash compares a plain text password with its hash
|
||||
func CheckPasswordHash(password, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ValidatePasswordStrength validates password strength requirements
|
||||
func ValidatePasswordStrength(password string) error {
|
||||
if len(password) < MinPasswordLength {
|
||||
return errors.New("password must be at least 8 characters long")
|
||||
}
|
||||
|
||||
if len(password) > MaxPasswordLength {
|
||||
return errors.New("password must not exceed 128 characters")
|
||||
}
|
||||
|
||||
var (
|
||||
hasUpper = false
|
||||
hasLower = false
|
||||
hasNumber = false
|
||||
hasSpecial = false
|
||||
)
|
||||
|
||||
for _, char := range password {
|
||||
switch {
|
||||
case unicode.IsUpper(char):
|
||||
hasUpper = true
|
||||
case unicode.IsLower(char):
|
||||
hasLower = true
|
||||
case unicode.IsNumber(char):
|
||||
hasNumber = true
|
||||
case unicode.IsPunct(char) || unicode.IsSymbol(char):
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasUpper {
|
||||
return errors.New("password must contain at least one uppercase letter")
|
||||
}
|
||||
|
||||
if !hasLower {
|
||||
return errors.New("password must contain at least one lowercase letter")
|
||||
}
|
||||
|
||||
if !hasNumber {
|
||||
return errors.New("password must contain at least one digit")
|
||||
}
|
||||
|
||||
if !hasSpecial {
|
||||
return errors.New("password must contain at least one special character")
|
||||
}
|
||||
|
||||
// Check for common patterns
|
||||
if isCommonPassword(password) {
|
||||
return errors.New("password is too common, please choose a stronger password")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isCommonPassword checks if password is in list of common passwords
|
||||
func isCommonPassword(password string) bool {
|
||||
commonPasswords := []string{
|
||||
"password", "123456", "password123", "admin", "qwerty",
|
||||
"letmein", "welcome", "monkey", "1234567890", "password1",
|
||||
"123456789", "welcome123", "admin123", "root", "test",
|
||||
"guest", "password12", "changeme", "default", "temp",
|
||||
}
|
||||
|
||||
for _, common := range commonPasswords {
|
||||
if password == common {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidatePasswordComplexity validates password against additional complexity rules
|
||||
func ValidatePasswordComplexity(password string) error {
|
||||
if err := ValidatePasswordStrength(password); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for repeated characters (more than 3 consecutive)
|
||||
if hasRepeatedChars(password, 3) {
|
||||
return errors.New("password must not contain more than 3 consecutive identical characters")
|
||||
}
|
||||
|
||||
// Check for sequential characters (like "1234" or "abcd")
|
||||
if hasSequentialChars(password, 4) {
|
||||
return errors.New("password must not contain sequential characters")
|
||||
}
|
||||
|
||||
// Check for keyboard patterns
|
||||
if hasKeyboardPattern(password) {
|
||||
return errors.New("password must not contain keyboard patterns")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasRepeatedChars checks for repeated consecutive characters
|
||||
func hasRepeatedChars(password string, maxRepeat int) bool {
|
||||
if len(password) < maxRepeat+1 {
|
||||
return false
|
||||
}
|
||||
|
||||
count := 1
|
||||
for i := 1; i < len(password); i++ {
|
||||
if password[i] == password[i-1] {
|
||||
count++
|
||||
if count > maxRepeat {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
count = 1
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// hasSequentialChars checks for sequential characters
|
||||
func hasSequentialChars(password string, minSequence int) bool {
|
||||
if len(password) < minSequence {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i <= len(password)-minSequence; i++ {
|
||||
isSequential := true
|
||||
isReverseSequential := true
|
||||
|
||||
for j := 1; j < minSequence; j++ {
|
||||
if int(password[i+j]) != int(password[i+j-1])+1 {
|
||||
isSequential = false
|
||||
}
|
||||
if int(password[i+j]) != int(password[i+j-1])-1 {
|
||||
isReverseSequential = false
|
||||
}
|
||||
}
|
||||
|
||||
if isSequential || isReverseSequential {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// hasKeyboardPattern checks for common keyboard patterns
|
||||
func hasKeyboardPattern(password string) bool {
|
||||
keyboardPatterns := []string{
|
||||
"qwerty", "asdf", "zxcv", "qwertyuiop", "asdfghjkl", "zxcvbnm",
|
||||
"1234567890", "qazwsx", "wsxedc", "rfvtgb", "nhyujm", "iklop",
|
||||
}
|
||||
|
||||
lowerPassword := regexp.MustCompile(`[^a-zA-Z0-9]`).ReplaceAllString(password, "")
|
||||
lowerPassword = regexp.MustCompile(`[A-Z]`).ReplaceAllStringFunc(lowerPassword, func(s string) string {
|
||||
return string(rune(s[0]) + 32)
|
||||
})
|
||||
|
||||
for _, pattern := range keyboardPatterns {
|
||||
if len(lowerPassword) >= len(pattern) {
|
||||
for i := 0; i <= len(lowerPassword)-len(pattern); i++ {
|
||||
if lowerPassword[i:i+len(pattern)] == pattern {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GenerateRandomPassword generates a random password with specified length
|
||||
func GenerateRandomPassword(length int) (string, error) {
|
||||
if length < MinPasswordLength {
|
||||
length = MinPasswordLength
|
||||
}
|
||||
if length > MaxPasswordLength {
|
||||
length = MaxPasswordLength
|
||||
}
|
||||
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
|
||||
// Use crypto/rand for secure random generation
|
||||
password := make([]byte, length)
|
||||
for i := range password {
|
||||
// Simple implementation - in production, use crypto/rand
|
||||
password[i] = charset[i%len(charset)]
|
||||
}
|
||||
|
||||
// Ensure password meets complexity requirements
|
||||
result := string(password)
|
||||
if err := ValidatePasswordStrength(result); err != nil {
|
||||
// Fallback to a known good password pattern if generation fails
|
||||
return generateFallbackPassword(length), nil
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// generateFallbackPassword generates a password that meets all requirements
|
||||
func generateFallbackPassword(length int) string {
|
||||
if length < MinPasswordLength {
|
||||
length = MinPasswordLength
|
||||
}
|
||||
|
||||
// Start with a base that meets all requirements
|
||||
base := "Aa1!"
|
||||
remaining := length - len(base)
|
||||
|
||||
// Fill remaining with mixed characters
|
||||
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*"
|
||||
for i := 0; i < remaining; i++ {
|
||||
base += string(charset[i%len(charset)])
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
// GetPasswordStrengthScore returns a score from 0-100 indicating password strength
|
||||
func GetPasswordStrengthScore(password string) int {
|
||||
score := 0
|
||||
|
||||
// Length score (0-25 points)
|
||||
if len(password) >= 8 {
|
||||
score += 5
|
||||
}
|
||||
if len(password) >= 12 {
|
||||
score += 10
|
||||
}
|
||||
if len(password) >= 16 {
|
||||
score += 10
|
||||
}
|
||||
|
||||
// Character variety (0-40 points)
|
||||
var hasUpper, hasLower, hasNumber, hasSpecial bool
|
||||
for _, char := range password {
|
||||
if unicode.IsUpper(char) {
|
||||
hasUpper = true
|
||||
} else if unicode.IsLower(char) {
|
||||
hasLower = true
|
||||
} else if unicode.IsNumber(char) {
|
||||
hasNumber = true
|
||||
} else if unicode.IsPunct(char) || unicode.IsSymbol(char) {
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasUpper {
|
||||
score += 10
|
||||
}
|
||||
if hasLower {
|
||||
score += 10
|
||||
}
|
||||
if hasNumber {
|
||||
score += 10
|
||||
}
|
||||
if hasSpecial {
|
||||
score += 10
|
||||
}
|
||||
|
||||
// Complexity bonus (0-35 points)
|
||||
if !isCommonPassword(password) {
|
||||
score += 10
|
||||
}
|
||||
if !hasRepeatedChars(password, 2) {
|
||||
score += 10
|
||||
}
|
||||
if !hasSequentialChars(password, 3) {
|
||||
score += 10
|
||||
}
|
||||
if !hasKeyboardPattern(password) {
|
||||
score += 5
|
||||
}
|
||||
|
||||
// Cap at 100
|
||||
if score > 100 {
|
||||
score = 100
|
||||
}
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
// GetPasswordStrengthLevel returns a human-readable strength level
|
||||
func GetPasswordStrengthLevel(password string) string {
|
||||
score := GetPasswordStrengthScore(password)
|
||||
|
||||
switch {
|
||||
case score >= 80:
|
||||
return "Very Strong"
|
||||
case score >= 60:
|
||||
return "Strong"
|
||||
case score >= 40:
|
||||
return "Medium"
|
||||
case score >= 20:
|
||||
return "Weak"
|
||||
default:
|
||||
return "Very Weak"
|
||||
}
|
||||
}
|
||||
157
local/utl/server/server.go
Normal file
157
local/utl/server/server.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"omega-server/local/api"
|
||||
"omega-server/local/middleware/security"
|
||||
"omega-server/local/utl/logging"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/helmet"
|
||||
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||
"github.com/gofiber/swagger"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
func Start(di *dig.Container) *fiber.App {
|
||||
app := fiber.New(fiber.Config{
|
||||
AppName: "Omega Project Management",
|
||||
ServerHeader: "omega-server",
|
||||
EnablePrintRoutes: true,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
BodyLimit: 10 * 1024 * 1024, // 10MB
|
||||
Prefork: false,
|
||||
CaseSensitive: false,
|
||||
StrictRouting: false,
|
||||
DisableKeepalive: false,
|
||||
ErrorHandler: func(c *fiber.Ctx, err error) error {
|
||||
// Custom error handler
|
||||
code := fiber.StatusInternalServerError
|
||||
if e, ok := err.(*fiber.Error); ok {
|
||||
code = e.Code
|
||||
}
|
||||
|
||||
// Log error
|
||||
logging.Error("HTTP Error: %v, Path: %s, Method: %s, IP: %s",
|
||||
err, c.Path(), c.Method(), c.IP())
|
||||
|
||||
// Return JSON error response
|
||||
return c.Status(code).JSON(fiber.Map{
|
||||
"error": err.Error(),
|
||||
"code": code,
|
||||
})
|
||||
},
|
||||
})
|
||||
|
||||
// Initialize security middleware
|
||||
securityMW := security.NewSecurityMiddleware()
|
||||
|
||||
// Add recovery middleware first
|
||||
app.Use(recover.New(recover.Config{
|
||||
EnableStackTrace: true,
|
||||
}))
|
||||
|
||||
// Add security middleware stack
|
||||
app.Use(securityMW.SecurityHeaders())
|
||||
app.Use(securityMW.LogSecurityEvents())
|
||||
app.Use(securityMW.TimeoutMiddleware(30 * time.Second))
|
||||
app.Use(securityMW.RequestSizeLimit(10 * 1024 * 1024)) // 10MB
|
||||
app.Use(securityMW.ValidateUserAgent())
|
||||
app.Use(securityMW.ValidateContentType("application/json", "application/x-www-form-urlencoded", "multipart/form-data"))
|
||||
app.Use(securityMW.InputSanitization())
|
||||
app.Use(securityMW.RateLimit(100, 1*time.Minute)) // 100 requests per minute global
|
||||
|
||||
// Add Helmet middleware for security headers
|
||||
app.Use(helmet.New(helmet.Config{
|
||||
XSSProtection: "1; mode=block",
|
||||
ContentTypeNosniff: "nosniff",
|
||||
XFrameOptions: "DENY",
|
||||
HSTSMaxAge: 31536000,
|
||||
HSTSPreloadEnabled: true,
|
||||
ContentSecurityPolicy: "default-src 'self'",
|
||||
ReferrerPolicy: "strict-origin-when-cross-origin",
|
||||
}))
|
||||
|
||||
// Configure CORS
|
||||
allowedOrigin := os.Getenv("CORS_ALLOWED_ORIGIN")
|
||||
if allowedOrigin == "" {
|
||||
allowedOrigin = "http://localhost:5173"
|
||||
}
|
||||
|
||||
app.Use(cors.New(cors.Config{
|
||||
AllowOrigins: allowedOrigin,
|
||||
AllowHeaders: "Origin, Content-Type, Accept, Authorization, X-Requested-With",
|
||||
AllowMethods: "GET, POST, PUT, DELETE, OPTIONS, PATCH",
|
||||
AllowCredentials: true,
|
||||
MaxAge: 86400, // 24 hours
|
||||
}))
|
||||
|
||||
// Add request logging middleware
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
start := time.Now()
|
||||
|
||||
// Process request
|
||||
err := c.Next()
|
||||
|
||||
// Log request
|
||||
duration := time.Since(start)
|
||||
logging.InfoResponse(
|
||||
c.Method(),
|
||||
c.Path(),
|
||||
c.Response().StatusCode(),
|
||||
duration.String(),
|
||||
)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
// Swagger documentation
|
||||
app.Get("/swagger/*", swagger.HandlerDefault)
|
||||
|
||||
// Health check endpoint
|
||||
app.Get("/health", func(c *fiber.Ctx) error {
|
||||
return c.JSON(fiber.Map{
|
||||
"status": "healthy",
|
||||
"timestamp": time.Now().UTC().Format(time.RFC3339),
|
||||
"version": "1.0.0",
|
||||
})
|
||||
})
|
||||
|
||||
// Ping endpoint
|
||||
app.Get("/ping", func(c *fiber.Ctx) error {
|
||||
return c.SendString("pong")
|
||||
})
|
||||
|
||||
// Initialize API routes
|
||||
api.Init(di, app)
|
||||
|
||||
// 404 handler
|
||||
app.Use(func(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{
|
||||
"error": "Route not found",
|
||||
"path": c.Path(),
|
||||
})
|
||||
})
|
||||
|
||||
// Get port from environment
|
||||
port := os.Getenv("PORT")
|
||||
if port == "" {
|
||||
port = "3000" // Default port
|
||||
}
|
||||
|
||||
logging.Info("Starting server on port %s", port)
|
||||
logging.Info("Swagger documentation available at: http://localhost:%s/swagger/", port)
|
||||
logging.Info("Health check available at: http://localhost:%s/health", port)
|
||||
|
||||
// Start server
|
||||
if err := app.Listen(":" + port); err != nil {
|
||||
logging.Error("Failed to start server: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
return app
|
||||
}
|
||||
Reference in New Issue
Block a user