init bootstrap
This commit is contained in:
365
local/utl/jwt/jwt.go
Normal file
365
local/utl/jwt/jwt.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
// Claims represents the JWT claims
|
||||
type Claims struct {
|
||||
UserID string `json:"userId"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Roles []string `json:"roles"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// TokenPair represents access and refresh tokens
|
||||
type TokenPair struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpiresAt time.Time `json:"expiresAt"`
|
||||
TokenType string `json:"tokenType"`
|
||||
}
|
||||
|
||||
var (
|
||||
jwtSecret []byte
|
||||
accessTokenTTL time.Duration
|
||||
refreshTokenTTL time.Duration
|
||||
issuer string
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrExpiredToken = errors.New("token has expired")
|
||||
ErrInvalidClaims = errors.New("invalid token claims")
|
||||
)
|
||||
|
||||
// Initialize initializes the JWT package with configuration
|
||||
func Initialize() error {
|
||||
// Get JWT secret from environment
|
||||
secret := os.Getenv("JWT_SECRET")
|
||||
if secret == "" {
|
||||
return errors.New("JWT_SECRET environment variable is required")
|
||||
}
|
||||
jwtSecret = []byte(secret)
|
||||
|
||||
// Get token TTL from environment or use defaults
|
||||
accessTTLStr := os.Getenv("JWT_ACCESS_TTL_HOURS")
|
||||
if accessTTLStr == "" {
|
||||
accessTTLStr = "24" // 24 hours default
|
||||
}
|
||||
accessHours, err := strconv.Atoi(accessTTLStr)
|
||||
if err != nil {
|
||||
accessHours = 24
|
||||
}
|
||||
accessTokenTTL = time.Duration(accessHours) * time.Hour
|
||||
|
||||
refreshTTLStr := os.Getenv("JWT_REFRESH_TTL_DAYS")
|
||||
if refreshTTLStr == "" {
|
||||
refreshTTLStr = "7" // 7 days default
|
||||
}
|
||||
refreshDays, err := strconv.Atoi(refreshTTLStr)
|
||||
if err != nil {
|
||||
refreshDays = 7
|
||||
}
|
||||
refreshTokenTTL = time.Duration(refreshDays) * 24 * time.Hour
|
||||
|
||||
// Get issuer from environment
|
||||
issuer = os.Getenv("JWT_ISSUER")
|
||||
if issuer == "" {
|
||||
issuer = "omega-server"
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateTokenPair generates both access and refresh tokens
|
||||
func GenerateTokenPair(userID, email, username string, roles []string) (*TokenPair, error) {
|
||||
if len(jwtSecret) == 0 {
|
||||
if err := Initialize(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
accessExpiresAt := now.Add(accessTokenTTL)
|
||||
refreshExpiresAt := now.Add(refreshTokenTTL)
|
||||
|
||||
// Create access token claims
|
||||
accessClaims := &Claims{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
Username: username,
|
||||
Roles: roles,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(accessExpiresAt),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Issuer: issuer,
|
||||
Subject: userID,
|
||||
ID: generateJTI(),
|
||||
},
|
||||
}
|
||||
|
||||
// Create refresh token claims (minimal data)
|
||||
refreshClaims := &Claims{
|
||||
UserID: userID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(refreshExpiresAt),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Issuer: issuer,
|
||||
Subject: userID,
|
||||
ID: generateJTI(),
|
||||
},
|
||||
}
|
||||
|
||||
// Generate access token
|
||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
|
||||
accessTokenString, err := accessToken.SignedString(jwtSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate refresh token
|
||||
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
|
||||
refreshTokenString, err := refreshToken.SignedString(jwtSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TokenPair{
|
||||
AccessToken: accessTokenString,
|
||||
RefreshToken: refreshTokenString,
|
||||
ExpiresAt: accessExpiresAt,
|
||||
TokenType: "Bearer",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateToken validates a JWT token and returns the claims
|
||||
func ValidateToken(tokenString string) (*Claims, error) {
|
||||
if len(jwtSecret) == 0 {
|
||||
if err := Initialize(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// Validate signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("invalid signing method")
|
||||
}
|
||||
return jwtSecret, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, ErrExpiredToken
|
||||
}
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, ErrInvalidClaims
|
||||
}
|
||||
|
||||
// Additional validation
|
||||
if claims.UserID == "" {
|
||||
return nil, ErrInvalidClaims
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// RefreshToken generates a new access token from a valid refresh token
|
||||
func RefreshToken(refreshTokenString string) (*TokenPair, error) {
|
||||
// Validate refresh token
|
||||
claims, err := ValidateToken(refreshTokenString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// For refresh tokens, we only have minimal user data
|
||||
// In a real application, you might want to fetch fresh user data from the database
|
||||
return GenerateTokenPair(claims.UserID, claims.Email, claims.Username, claims.Roles)
|
||||
}
|
||||
|
||||
// ExtractTokenFromHeader extracts JWT token from Authorization header
|
||||
func ExtractTokenFromHeader(authHeader string) (string, error) {
|
||||
if authHeader == "" {
|
||||
return "", errors.New("authorization header is required")
|
||||
}
|
||||
|
||||
const bearerPrefix = "Bearer "
|
||||
if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix {
|
||||
return "", errors.New("invalid authorization header format")
|
||||
}
|
||||
|
||||
token := authHeader[len(bearerPrefix):]
|
||||
if token == "" {
|
||||
return "", errors.New("token is required")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// IsTokenExpired checks if a token is expired without validating signature
|
||||
func IsTokenExpired(tokenString string) bool {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return jwtSecret, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
return claims.ExpiresAt.Before(time.Now())
|
||||
}
|
||||
|
||||
// GetTokenClaims extracts claims from token without validating signature (use carefully)
|
||||
func GetTokenClaims(tokenString string) (*Claims, error) {
|
||||
token, _, err := new(jwt.Parser).ParseUnverified(tokenString, &Claims{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok {
|
||||
return nil, ErrInvalidClaims
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// RevokeToken adds token to revocation list (implement with your chosen storage)
|
||||
func RevokeToken(tokenString string) error {
|
||||
// In a real implementation, you would store revoked tokens in a database or cache
|
||||
// with their JTI (JWT ID) and expiration time for efficient cleanup
|
||||
|
||||
claims, err := GetTokenClaims(tokenString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store the JTI in your revocation storage
|
||||
// Example: revokedTokens[claims.ID] = claims.ExpiresAt
|
||||
|
||||
_ = claims // Placeholder to avoid unused variable error
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsTokenRevoked checks if a token has been revoked (implement with your chosen storage)
|
||||
func IsTokenRevoked(tokenString string) bool {
|
||||
// In a real implementation, check if the token's JTI is in the revocation list
|
||||
|
||||
claims, err := GetTokenClaims(tokenString)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check revocation storage
|
||||
// Example: _, revoked := revokedTokens[claims.ID]
|
||||
// return revoked
|
||||
|
||||
_ = claims // Placeholder to avoid unused variable error
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// generateJTI generates a unique JWT ID
|
||||
func generateJTI() string {
|
||||
// In a real implementation, use a proper UUID generator
|
||||
return strconv.FormatInt(time.Now().UnixNano(), 36)
|
||||
}
|
||||
|
||||
// GetTokenExpirationTime returns the expiration time of a token
|
||||
func GetTokenExpirationTime(tokenString string) (time.Time, error) {
|
||||
claims, err := GetTokenClaims(tokenString)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
if claims.ExpiresAt == nil {
|
||||
return time.Time{}, errors.New("token has no expiration time")
|
||||
}
|
||||
|
||||
return claims.ExpiresAt.Time, nil
|
||||
}
|
||||
|
||||
// GetTokenRemainingTime returns the remaining time before token expires
|
||||
func GetTokenRemainingTime(tokenString string) (time.Duration, error) {
|
||||
expirationTime, err := GetTokenExpirationTime(tokenString)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
remaining := time.Until(expirationTime)
|
||||
if remaining < 0 {
|
||||
return 0, ErrExpiredToken
|
||||
}
|
||||
|
||||
return remaining, nil
|
||||
}
|
||||
|
||||
// HasRole checks if the token claims contain a specific role
|
||||
func (c *Claims) HasRole(role string) bool {
|
||||
for _, r := range c.Roles {
|
||||
if r == role {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasAnyRole checks if the token claims contain any of the specified roles
|
||||
func (c *Claims) HasAnyRole(roles []string) bool {
|
||||
for _, role := range roles {
|
||||
if c.HasRole(role) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasAllRoles checks if the token claims contain all of the specified roles
|
||||
func (c *Claims) HasAllRoles(roles []string) bool {
|
||||
for _, role := range roles {
|
||||
if !c.HasRole(role) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsAdmin checks if the user has admin role
|
||||
func (c *Claims) IsAdmin() bool {
|
||||
return c.HasRole("admin")
|
||||
}
|
||||
|
||||
// GetUserInfo returns basic user information from claims
|
||||
func (c *Claims) GetUserInfo() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"userId": c.UserID,
|
||||
"email": c.Email,
|
||||
"username": c.Username,
|
||||
"roles": c.Roles,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateToken generates a JWT token for a user (backward compatibility)
|
||||
func GenerateToken(userID, email, username string, roles []string) (string, error) {
|
||||
tokenPair, err := GenerateTokenPair(userID, email, username, roles)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return tokenPair.AccessToken, nil
|
||||
}
|
||||
Reference in New Issue
Block a user