// Package jwt issues and validates HMAC-signed access/refresh JSON Web Tokens
// whose signing secret, lifetimes and issuer are read from environment
// variables (JWT_SECRET, JWT_ACCESS_TTL_HOURS, JWT_REFRESH_TTL_DAYS,
// JWT_ISSUER).
package jwt

import (
	"crypto/rand"
	"encoding/hex"
	"errors"
	"math"
	"os"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/golang-jwt/jwt/v4"
)

// Claims represents the JWT claims carried by both access and refresh tokens.
type Claims struct {
	UserID   string   `json:"userId"`
	Email    string   `json:"email"`
	Username string   `json:"username"`
	Roles    []string `json:"roles"`
	// TokenUse is "access" or "refresh" for tokens issued by
	// GenerateTokenPair. It is empty in tokens minted before the field
	// existed; such legacy tokens are accepted by both ValidateToken and
	// RefreshToken.
	TokenUse string `json:"tokenUse,omitempty"`
	jwt.RegisteredClaims
}

// TokenPair represents access and refresh tokens returned to a client.
type TokenPair struct {
	AccessToken  string    `json:"accessToken"`
	RefreshToken string    `json:"refreshToken"`
	ExpiresAt    time.Time `json:"expiresAt"` // expiry of AccessToken
	TokenType    string    `json:"tokenType"`
}

// Values stored in Claims.TokenUse.
const (
	tokenUseAccess  = "access"
	tokenUseRefresh = "refresh"
)

var (
	// configMu guards the configuration variables below. Initialize can run
	// lazily from concurrent callers, so every read and write goes through it.
	configMu        sync.RWMutex
	jwtSecret       []byte
	accessTokenTTL  time.Duration
	refreshTokenTTL time.Duration
	issuer          string

	// jtiFallbackCounter keeps fallback JTIs unique within the process.
	jtiFallbackCounter uint64

	// ErrInvalidToken is returned for tokens that fail parsing or signature
	// verification, or that are presented for the wrong use.
	ErrInvalidToken = errors.New("invalid token")
	// ErrExpiredToken is returned for tokens whose exp claim has passed.
	ErrExpiredToken = errors.New("token has expired")
	// ErrInvalidClaims is returned when a verified token's claims are unusable.
	ErrInvalidClaims = errors.New("invalid token claims")
)

// tokenConfig is an immutable snapshot of the package configuration.
type tokenConfig struct {
	secret     []byte
	accessTTL  time.Duration
	refreshTTL time.Duration
	issuer     string
}

// Initialize (re)loads the package configuration from the environment.
//
// JWT_SECRET is required. JWT_ACCESS_TTL_HOURS (default 24) and
// JWT_REFRESH_TTL_DAYS (default 7) fall back to their defaults when unset,
// not an integer, not positive, or too large to represent. JWT_ISSUER defaults
// to "omega-server". The new values are published atomically, so concurrent
// callers never observe a half-initialized configuration.
func Initialize() error {
	secret := os.Getenv("JWT_SECRET")
	if secret == "" {
		return errors.New("JWT_SECRET environment variable is required")
	}

	accessTTL := positiveDurationFromEnv("JWT_ACCESS_TTL_HOURS", 24, time.Hour)
	refreshTTL := positiveDurationFromEnv("JWT_REFRESH_TTL_DAYS", 7, 24*time.Hour)

	iss := os.Getenv("JWT_ISSUER")
	if iss == "" {
		iss = "omega-server"
	}

	configMu.Lock()
	defer configMu.Unlock()
	jwtSecret = []byte(secret)
	accessTokenTTL = accessTTL
	refreshTokenTTL = refreshTTL
	issuer = iss
	return nil
}

// positiveDurationFromEnv parses the environment variable key as a whole
// number of units. Unset, non-integer, non-positive or overflowing values fall
// back to def units, so a misconfiguration cannot yield tokens born expired.
func positiveDurationFromEnv(key string, def int64, unit time.Duration) time.Duration {
	n, err := strconv.ParseInt(os.Getenv(key), 10, 64)
	if err != nil || n <= 0 || n > math.MaxInt64/int64(unit) {
		n = def
	}
	return time.Duration(n) * unit
}

// snapshotTokenConfig copies the current configuration under the read lock.
func snapshotTokenConfig() tokenConfig {
	configMu.RLock()
	defer configMu.RUnlock()
	return tokenConfig{
		secret:     jwtSecret,
		accessTTL:  accessTokenTTL,
		refreshTTL: refreshTokenTTL,
		issuer:     issuer,
	}
}

// currentTokenConfig returns the configuration, running Initialize first if
// no secret has been loaded yet.
func currentTokenConfig() (tokenConfig, error) {
	if cfg := snapshotTokenConfig(); len(cfg.secret) > 0 {
		return cfg, nil
	}
	if err := Initialize(); err != nil {
		return tokenConfig{}, err
	}
	return snapshotTokenConfig(), nil
}

// newRegisteredClaims builds the standard claims shared by access and refresh
// tokens; each call gets a fresh JTI.
func newRegisteredClaims(cfg tokenConfig, subject string, issuedAt, expiresAt time.Time) jwt.RegisteredClaims {
	return jwt.RegisteredClaims{
		ExpiresAt: jwt.NewNumericDate(expiresAt),
		IssuedAt:  jwt.NewNumericDate(issuedAt),
		NotBefore: jwt.NewNumericDate(issuedAt),
		Issuer:    cfg.issuer,
		Subject:   subject,
		ID:        generateJTI(),
	}
}

// signClaims signs claims with HS256 using the configured secret.
func signClaims(cfg tokenConfig, claims *Claims) (string, error) {
	return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(cfg.secret)
}

// GenerateTokenPair issues an access token and a refresh token for the user.
//
// Both tokens carry the same identity claims (user ID, email, username,
// roles), so RefreshToken can re-issue an access token without losing them.
// They are distinguished by Claims.TokenUse. An empty userID is rejected
// because ValidateToken would never accept the resulting tokens.
func GenerateTokenPair(userID, email, username string, roles []string) (*TokenPair, error) {
	if userID == "" {
		return nil, errors.New("userID is required")
	}
	cfg, err := currentTokenConfig()
	if err != nil {
		return nil, err
	}

	now := time.Now()
	accessExpiresAt := now.Add(cfg.accessTTL)
	refreshExpiresAt := now.Add(cfg.refreshTTL)

	accessToken, err := signClaims(cfg, &Claims{
		UserID:           userID,
		Email:            email,
		Username:         username,
		Roles:            roles,
		TokenUse:         tokenUseAccess,
		RegisteredClaims: newRegisteredClaims(cfg, userID, now, accessExpiresAt),
	})
	if err != nil {
		return nil, err
	}

	refreshToken, err := signClaims(cfg, &Claims{
		UserID:           userID,
		Email:            email,
		Username:         username,
		Roles:            roles,
		TokenUse:         tokenUseRefresh,
		RegisteredClaims: newRegisteredClaims(cfg, userID, now, refreshExpiresAt),
	})
	if err != nil {
		return nil, err
	}

	return &TokenPair{
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		ExpiresAt:    accessExpiresAt,
		TokenType:    "Bearer",
	}, nil
}

// parseVerified parses tokenString, verifies its HMAC signature against the
// configured secret, and lets the jwt library check the registered time
// claims. Expired tokens map to ErrExpiredToken; every other parse or
// verification failure maps to ErrInvalidToken.
func parseVerified(tokenString string) (*Claims, error) {
	cfg, err := currentTokenConfig()
	if err != nil {
		return nil, err
	}

	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
		// Only HMAC is accepted; this blocks algorithm-substitution attacks.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, errors.New("invalid signing method")
		}
		return cfg.secret, nil
	})
	if err != nil {
		if errors.Is(err, jwt.ErrTokenExpired) {
			return nil, ErrExpiredToken
		}
		return nil, ErrInvalidToken
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, ErrInvalidClaims
	}
	return claims, nil
}

// verifyForUse verifies tokenString and additionally requires a user ID and
// that the token was not issued for forbiddenUse.
func verifyForUse(tokenString, forbiddenUse string) (*Claims, error) {
	claims, err := parseVerified(tokenString)
	if err != nil {
		return nil, err
	}
	if claims.UserID == "" {
		return nil, ErrInvalidClaims
	}
	if claims.TokenUse == forbiddenUse {
		return nil, ErrInvalidToken
	}
	return claims, nil
}

// ValidateToken verifies an access token and returns its claims.
//
// It returns ErrExpiredToken for expired tokens. It returns ErrInvalidToken
// for tokens that fail verification and for refresh tokens, which must not be
// used as bearer credentials. It returns ErrInvalidClaims when the user ID is
// missing, and the Initialize error when the package cannot be configured.
func ValidateToken(tokenString string) (*Claims, error) {
	return verifyForUse(tokenString, tokenUseRefresh)
}

// RefreshToken issues a new token pair from a valid refresh token.
//
// Access tokens are rejected with ErrInvalidToken. Legacy tokens without a
// TokenUse claim are still accepted. The new pair reuses the identity claims
// stored in the refresh token; re-read them from the user store here if role
// changes must take effect before the refresh token expires.
func RefreshToken(refreshTokenString string) (*TokenPair, error) {
	claims, err := verifyForUse(refreshTokenString, tokenUseAccess)
	if err != nil {
		return nil, err
	}
	return GenerateTokenPair(claims.UserID, claims.Email, claims.Username, claims.Roles)
}

// ExtractTokenFromHeader extracts the token from an "Authorization: Bearer
// <token>" header value. The scheme is matched case-insensitively, as HTTP
// authentication schemes are, and whitespace around the token is trimmed.
func ExtractTokenFromHeader(authHeader string) (string, error) {
	if authHeader == "" {
		return "", errors.New("authorization header is required")
	}

	const bearerPrefix = "Bearer "
	if len(authHeader) < len(bearerPrefix) || !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
		return "", errors.New("invalid authorization header format")
	}

	token := strings.TrimSpace(authHeader[len(bearerPrefix):])
	if token == "" {
		return "", errors.New("token is required")
	}
	return token, nil
}

// IsTokenExpired reports whether tokenString must be treated as expired.
//
// The signature IS verified. Any token that fails verification (malformed,
// bad signature, already expired, or the package cannot be initialized) is
// reported as expired, so the check fails closed. A verified token without an
// exp claim is reported as not expired.
func IsTokenExpired(tokenString string) bool {
	claims, err := parseVerified(tokenString)
	if err != nil {
		return true
	}
	if claims.ExpiresAt == nil {
		return false
	}
	return claims.ExpiresAt.Before(time.Now())
}

// GetTokenClaims extracts claims from a token WITHOUT verifying its signature
// or expiry. Never base authorization decisions on the result.
func GetTokenClaims(tokenString string) (*Claims, error) {
	token, _, err := new(jwt.Parser).ParseUnverified(tokenString, &Claims{})
	if err != nil {
		return nil, err
	}

	claims, ok := token.Claims.(*Claims)
	if !ok {
		return nil, ErrInvalidClaims
	}
	return claims, nil
}

// RevokeToken is meant to add a token's JTI to a revocation list.
//
// NOTE(review): this is a stub. It only parses the token (unverified) and
// stores nothing, so the token remains usable. A real implementation should
// persist claims.ID until claims.ExpiresAt, and should verify the token first
// so callers cannot revoke arbitrary JTIs.
func RevokeToken(tokenString string) error {
	claims, err := GetTokenClaims(tokenString)
	if err != nil {
		return err
	}

	// Store the JTI in your revocation storage.
	// Example: revokedTokens[claims.ID] = claims.ExpiresAt
	_ = claims // placeholder until revocation storage exists
	return nil
}

// IsTokenRevoked is meant to report whether a token's JTI was revoked.
//
// NOTE(review): this is a stub. Unparseable tokens are reported as revoked;
// every parseable token is reported as not revoked.
func IsTokenRevoked(tokenString string) bool {
	claims, err := GetTokenClaims(tokenString)
	if err != nil {
		return true
	}

	// Check revocation storage.
	// Example: _, revoked := revokedTokens[claims.ID]; return revoked
	_ = claims // placeholder until revocation storage exists
	return false
}

// generateJTI returns a unique JWT ID made of 128 random bits, hex-encoded.
//
// A timestamp alone is not unique: the access and refresh tokens are minted
// back to back and can share a nanosecond reading on coarse clocks.
func generateJTI() string {
	var b [16]byte
	if _, err := rand.Read(b[:]); err == nil {
		return hex.EncodeToString(b[:])
	}
	// crypto/rand failing is exceptional. Fall back to a timestamp plus a
	// process-wide counter, which is still unique within this process.
	seq := atomic.AddUint64(&jtiFallbackCounter, 1)
	return strconv.FormatInt(time.Now().UnixNano(), 36) + "-" + strconv.FormatUint(seq, 36)
}

// GetTokenExpirationTime returns the exp claim of a token (signature not
// verified). It errors when the token has no exp claim.
func GetTokenExpirationTime(tokenString string) (time.Time, error) {
	claims, err := GetTokenClaims(tokenString)
	if err != nil {
		return time.Time{}, err
	}

	if claims.ExpiresAt == nil {
		return time.Time{}, errors.New("token has no expiration time")
	}
	return claims.ExpiresAt.Time, nil
}

// GetTokenRemainingTime returns how long until the token expires (signature
// not verified). It returns ErrExpiredToken if the expiry has already passed.
func GetTokenRemainingTime(tokenString string) (time.Duration, error) {
	expirationTime, err := GetTokenExpirationTime(tokenString)
	if err != nil {
		return 0, err
	}

	remaining := time.Until(expirationTime)
	if remaining < 0 {
		return 0, ErrExpiredToken
	}
	return remaining, nil
}

// HasRole reports whether the claims contain role.
func (c *Claims) HasRole(role string) bool {
	for _, r := range c.Roles {
		if r == role {
			return true
		}
	}
	return false
}

// HasAnyRole reports whether the claims contain at least one of roles.
// It returns false for an empty roles list.
func (c *Claims) HasAnyRole(roles []string) bool {
	for _, role := range roles {
		if c.HasRole(role) {
			return true
		}
	}
	return false
}

// HasAllRoles reports whether the claims contain every one of roles.
// It returns true for an empty roles list.
func (c *Claims) HasAllRoles(roles []string) bool {
	for _, role := range roles {
		if !c.HasRole(role) {
			return false
		}
	}
	return true
}

// IsAdmin reports whether the claims contain the "admin" role.
func (c *Claims) IsAdmin() bool {
	return c.HasRole("admin")
}

// GetUserInfo returns the user-identifying claims as a map keyed by their JSON
// names.
func (c *Claims) GetUserInfo() map[string]interface{} {
	return map[string]interface{}{
		"userId":   c.UserID,
		"email":    c.Email,
		"username": c.Username,
		"roles":    c.Roles,
	}
}

// GenerateToken returns only the access token of a new pair; it is kept for
// backward compatibility.
func GenerateToken(userID, email, username string, roles []string) (string, error) {
	tokenPair, err := GenerateTokenPair(userID, email, username, roles)
	if err != nil {
		return "", err
	}
	return tokenPair.AccessToken, nil
}