Files
omega-server/local/utl/jwt/jwt.go
Fran Jurmanović 016728532c init bootstrap
2025-07-06 15:02:09 +02:00

366 lines
9.1 KiB
Go

package jwt
import (
	"crypto/rand"
	"encoding/hex"
	"errors"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v4"
)
// Claims is the payload of every token this package signs: a few
// application-specific user fields plus the embedded standard registered
// claims. GenerateTokenPair populates exp, iat, nbf, iss, sub and jti.
type Claims struct {
	// UserID identifies the user; GenerateTokenPair also copies it into the
	// "sub" claim, and ValidateToken rejects tokens where it is empty.
	UserID string `json:"userId"`
	// Email, Username and Roles are only set on access tokens; refresh
	// tokens built by GenerateTokenPair leave them empty.
	Email    string   `json:"email"`
	Username string   `json:"username"`
	Roles    []string `json:"roles"`
	jwt.RegisteredClaims
}
// TokenPair bundles the signed access and refresh tokens produced by
// GenerateTokenPair. The JSON tags use camelCase field names.
type TokenPair struct {
	AccessToken  string    `json:"accessToken"`  // signed access JWT
	RefreshToken string    `json:"refreshToken"` // signed refresh JWT
	ExpiresAt    time.Time `json:"expiresAt"`    // expiry of the ACCESS token (not the refresh token)
	TokenType    string    `json:"tokenType"`    // set to "Bearer" by GenerateTokenPair
}
// Runtime configuration populated by Initialize.
//
// NOTE(review): these are plain package variables with no synchronisation;
// Initialize (including the lazy calls made by other functions when the
// secret is empty) writes them, so it should complete before tokens are
// issued or validated concurrently.
var (
	jwtSecret       []byte        // HMAC signing key; empty until Initialize succeeds
	accessTokenTTL  time.Duration // lifetime of access tokens
	refreshTokenTTL time.Duration // lifetime of refresh tokens
	issuer          string        // value written to the "iss" claim
)

// Sentinel errors returned by this package; compare with errors.Is.
var (
	// ErrInvalidToken reports a token that could not be parsed or verified.
	ErrInvalidToken = errors.New("invalid token")
	// ErrExpiredToken reports a token whose expiry time has passed.
	ErrExpiredToken = errors.New("token has expired")
	// ErrInvalidClaims reports a token whose claims are missing or unusable.
	ErrInvalidClaims = errors.New("invalid token claims")
)
// Initialize loads the package configuration from the environment:
//
//   - JWT_SECRET (required): HMAC key used to sign and verify tokens.
//   - JWT_ACCESS_TTL_HOURS (default 24): access-token lifetime in hours.
//   - JWT_REFRESH_TTL_DAYS (default 7): refresh-token lifetime in days.
//   - JWT_ISSUER (default "omega-server"): value of the "iss" claim.
//
// TTL values that are unset, not integers, zero or negative fall back to
// their defaults. A non-positive TTL would otherwise make every newly issued
// token expire immediately. Surrounding whitespace in TTL values is ignored.
//
// It returns an error only when JWT_SECRET is empty; in that case no other
// configuration is changed.
func Initialize() error {
	secret := os.Getenv("JWT_SECRET")
	if secret == "" {
		return errors.New("JWT_SECRET environment variable is required")
	}
	jwtSecret = []byte(secret)

	accessTokenTTL = time.Duration(positiveIntFromEnv("JWT_ACCESS_TTL_HOURS", 24)) * time.Hour
	refreshTokenTTL = time.Duration(positiveIntFromEnv("JWT_REFRESH_TTL_DAYS", 7)) * 24 * time.Hour

	issuer = os.Getenv("JWT_ISSUER")
	if issuer == "" {
		issuer = "omega-server"
	}
	return nil
}

// positiveIntFromEnv returns the integer value of environment variable key,
// or def when the variable is unset, is not an integer, or is not positive.
func positiveIntFromEnv(key string, def int) int {
	v, err := strconv.Atoi(strings.TrimSpace(os.Getenv(key)))
	if err != nil || v <= 0 {
		return def
	}
	return v
}
// GenerateTokenPair issues a new HS256-signed access token and refresh token
// for the given user.
//
// The access token carries the user ID, email, username and roles and
// expires after accessTokenTTL. The refresh token carries only the user ID
// and expires after refreshTokenTTL. Both tokens share the same issued-at /
// not-before instant and receive their own "jti".
//
// If no secret is configured yet, Initialize is called first and its error is
// returned. The returned TokenPair.ExpiresAt is the access token's expiry.
func GenerateTokenPair(userID, email, username string, roles []string) (*TokenPair, error) {
	if len(jwtSecret) == 0 {
		if err := Initialize(); err != nil {
			return nil, err
		}
	}

	issuedAt := time.Now()
	accessExp := issuedAt.Add(accessTokenTTL)

	// registered builds the standard claims common to both tokens; only the
	// expiry differs between them. Each call draws a fresh JTI.
	registered := func(expiry time.Time) jwt.RegisteredClaims {
		return jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(expiry),
			IssuedAt:  jwt.NewNumericDate(issuedAt),
			NotBefore: jwt.NewNumericDate(issuedAt),
			Issuer:    issuer,
			Subject:   userID,
			ID:        generateJTI(),
		}
	}

	access := &Claims{
		UserID:           userID,
		Email:            email,
		Username:         username,
		Roles:            roles,
		RegisteredClaims: registered(accessExp),
	}
	// Refresh tokens carry minimal data: only the user ID.
	refresh := &Claims{
		UserID:           userID,
		RegisteredClaims: registered(issuedAt.Add(refreshTokenTTL)),
	}

	sign := func(c *Claims) (string, error) {
		return jwt.NewWithClaims(jwt.SigningMethodHS256, c).SignedString(jwtSecret)
	}

	signedAccess, err := sign(access)
	if err != nil {
		return nil, err
	}
	signedRefresh, err := sign(refresh)
	if err != nil {
		return nil, err
	}

	return &TokenPair{
		AccessToken:  signedAccess,
		RefreshToken: signedRefresh,
		ExpiresAt:    accessExp,
		TokenType:    "Bearer",
	}, nil
}
// ValidateToken parses tokenString and returns its Claims. It verifies the
// HMAC signature against the configured secret and checks the time-based
// registered claims.
//
// If no secret is configured yet, Initialize is called first and its error is
// returned. Other outcomes:
//   - ErrExpiredToken: the token is correctly signed but has expired.
//   - ErrInvalidToken: the token is malformed, uses a non-HMAC signing
//     method, has a bad signature, or fails any other validation.
//   - ErrInvalidClaims: the claims cannot be decoded or carry no user ID.
//
// NOTE(review): access and refresh tokens share the same claim shape, so
// this accepts either kind of token.
func ValidateToken(tokenString string) (*Claims, error) {
	if len(jwtSecret) == 0 {
		if err := Initialize(); err != nil {
			return nil, err
		}
	}

	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
		// Only HMAC tokens may be verified with the shared secret.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, errors.New("invalid signing method")
		}
		return jwtSecret, nil
	})
	if err != nil {
		// jwt/v4 validates the claims before checking the signature and can
		// report both failures in a single error. Only call a token "expired"
		// when its signature was valid: a forged or tampered token must never
		// be reported as merely expired.
		if errors.Is(err, jwt.ErrTokenExpired) && !errors.Is(err, jwt.ErrTokenSignatureInvalid) {
			return nil, ErrExpiredToken
		}
		return nil, ErrInvalidToken
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, ErrInvalidClaims
	}
	// A token without a user ID cannot be attributed to anyone.
	if claims.UserID == "" {
		return nil, ErrInvalidClaims
	}
	return claims, nil
}
// RefreshToken validates refreshTokenString with ValidateToken and, on
// success, issues a brand-new TokenPair for the identity the token carries.
// Validation errors are returned unchanged.
//
// NOTE(review): refresh tokens minted by GenerateTokenPair carry only the user
// ID, so the new pair has empty email, username and roles. ValidateToken also
// accepts access tokens here. Callers needing the user's current profile
// should reload it and call GenerateTokenPair directly.
func RefreshToken(refreshTokenString string) (*TokenPair, error) {
	validated, err := ValidateToken(refreshTokenString)
	if err != nil {
		return nil, err
	}

	// Carry the identity forward exactly as the token recorded it.
	userID, email, username, roles := validated.UserID, validated.Email, validated.Username, validated.Roles
	return GenerateTokenPair(userID, email, username, roles)
}
// ExtractTokenFromHeader returns the token from an HTTP Authorization header
// value of the form "Bearer <token>".
//
// The "Bearer" scheme is matched case-insensitively, because HTTP
// authentication scheme names are case-insensitive. Whitespace around the
// token is trimmed. An error is returned if the header is empty, does not use
// the Bearer scheme, or carries no token.
func ExtractTokenFromHeader(authHeader string) (string, error) {
	if authHeader == "" {
		return "", errors.New("authorization header is required")
	}

	const bearerPrefix = "Bearer "
	if len(authHeader) < len(bearerPrefix) || !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
		return "", errors.New("invalid authorization header format")
	}

	token := strings.TrimSpace(authHeader[len(bearerPrefix):])
	if token == "" {
		return "", errors.New("token is required")
	}
	return token, nil
}
// IsTokenExpired checks if a token is expired without validating signature
func IsTokenExpired(tokenString string) bool {
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
return jwtSecret, nil
})
if err != nil {
return true
}
claims, ok := token.Claims.(*Claims)
if !ok {
return true
}
return claims.ExpiresAt.Before(time.Now())
}
// GetTokenClaims decodes the claims of tokenString WITHOUT verifying its
// signature or validating its time-based claims. The result is
// attacker-controllable and must not be used for authorization decisions;
// use ValidateToken for that.
//
// It returns the parser's error for malformed tokens and ErrInvalidClaims if
// the decoded claims are not a *Claims.
func GetTokenClaims(tokenString string) (*Claims, error) {
	var parser jwt.Parser
	tok, _, err := parser.ParseUnverified(tokenString, &Claims{})
	if err != nil {
		return nil, err
	}

	if claims, ok := tok.Claims.(*Claims); ok {
		return claims, nil
	}
	return nil, ErrInvalidClaims
}
// RevokeToken is a placeholder for token revocation.
//
// It only checks that tokenString can be decoded by GetTokenClaims (no
// signature verification) and returns that decoding error, if any. Nothing
// is stored, so a nil return does NOT mean the token has been revoked.
//
// A real implementation would record the token's JTI (claims.ID) with its
// expiry (claims.ExpiresAt) in a revocation store. The expiry lets entries be
// purged once the token would have expired anyway.
func RevokeToken(tokenString string) error {
	// TODO: persist claims.ID together with claims.ExpiresAt.
	_, err := GetTokenClaims(tokenString)
	return err
}
// IsTokenRevoked is a placeholder revocation check.
//
// It returns true when tokenString cannot be decoded by GetTokenClaims and
// false otherwise. No revocation list is consulted yet, so every decodable
// token is reported as not revoked.
func IsTokenRevoked(tokenString string) bool {
	// TODO: look up the token's JTI (claims.ID) in the revocation store.
	_, err := GetTokenClaims(tokenString)
	return err != nil
}
// generateJTI returns a random identifier for the "jti" claim: 128 bits from
// crypto/rand, hex-encoded as 32 characters.
//
// A clock-based ID is not enough here. GenerateTokenPair mints two tokens
// back-to-back, and concurrent requests can observe the same clock reading,
// so timestamp IDs may repeat. Collisions between random 128-bit IDs are
// practically impossible.
func generateJTI() string {
	var b [16]byte
	if _, err := rand.Read(b[:]); err != nil {
		// crypto/rand failing is practically unheard of; degrade to a
		// clock-based ID rather than failing token issuance.
		return strconv.FormatInt(time.Now().UnixNano(), 36)
	}
	return hex.EncodeToString(b[:])
}
// GetTokenExpirationTime returns the "exp" claim of tokenString.
//
// Claims are read with GetTokenClaims, so the signature is NOT verified. It
// returns an error if the token cannot be decoded or has no expiry.
func GetTokenExpirationTime(tokenString string) (time.Time, error) {
	claims, err := GetTokenClaims(tokenString)
	switch {
	case err != nil:
		return time.Time{}, err
	case claims.ExpiresAt == nil:
		return time.Time{}, errors.New("token has no expiration time")
	default:
		return claims.ExpiresAt.Time, nil
	}
}
// GetTokenRemainingTime returns how long tokenString has left before its
// "exp" time. The value comes from GetTokenExpirationTime, so the signature is
// not verified.
//
// It returns 0 and ErrExpiredToken once the expiry has passed, and propagates
// any error from GetTokenExpirationTime.
func GetTokenRemainingTime(tokenString string) (time.Duration, error) {
	expiresAt, err := GetTokenExpirationTime(tokenString)
	if err != nil {
		return 0, err
	}

	if left := time.Until(expiresAt); left >= 0 {
		return left, nil
	}
	return 0, ErrExpiredToken
}
// HasRole reports whether role appears in c.Roles. Matching is exact and
// case-sensitive.
func (c *Claims) HasRole(role string) bool {
	for i := range c.Roles {
		if c.Roles[i] == role {
			return true
		}
	}
	return false
}
// HasAnyRole reports whether c holds at least one of roles. It stops at the
// first match and returns false when roles is empty.
func (c *Claims) HasAnyRole(roles []string) bool {
	found := false
	for i := 0; i < len(roles) && !found; i++ {
		found = c.HasRole(roles[i])
	}
	return found
}
// HasAllRoles reports whether c holds every role in roles. It stops at the
// first missing role. An empty roles list is vacuously satisfied and returns
// true.
func (c *Claims) HasAllRoles(roles []string) bool {
	missing := false
	for i := 0; i < len(roles) && !missing; i++ {
		missing = !c.HasRole(roles[i])
	}
	return !missing
}
// IsAdmin reports whether the claims include the "admin" role (exact,
// case-sensitive match via HasRole).
func (c *Claims) IsAdmin() bool {
	const adminRole = "admin"
	return c.HasRole(adminRole)
}
// GetUserInfo returns the user-identifying claims as a generic map. Its keys
// are "userId", "email", "username" and "roles", matching the fields' JSON
// tags. The roles value is c.Roles itself, not a copy.
func (c *Claims) GetUserInfo() map[string]interface{} {
	info := make(map[string]interface{}, 4)
	info["userId"] = c.UserID
	info["email"] = c.Email
	info["username"] = c.Username
	info["roles"] = c.Roles
	return info
}
// GenerateToken issues a token pair via GenerateTokenPair and returns only the
// signed access token; the refresh token is discarded. It is retained for
// backward compatibility with callers that expect a single token. On error it
// returns an empty string and the error from GenerateTokenPair.
func GenerateToken(userID, email, username string, roles []string) (string, error) {
	var access string
	pair, err := GenerateTokenPair(userID, email, username, roles)
	if err == nil {
		access = pair.AccessToken
	}
	return access, err
}