366 lines
9.1 KiB
Go
366 lines
9.1 KiB
Go
package jwt
|
|
|
|
import (
	"crypto/rand"
	"encoding/hex"
	"errors"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v4"
)
|
|
|
|
// Claims is the payload encoded in tokens issued by this package.
//
// It combines application-specific user fields with the standard registered
// JWT claims (embedded jwt.RegisteredClaims). Access tokens built by
// GenerateTokenPair populate every user field; refresh tokens set only UserID
// alongside the registered claims.
type Claims struct {
	// UserID identifies the user; ValidateToken rejects tokens where it is empty.
	UserID string `json:"userId"`
	// Email is the user's email address (empty in refresh tokens).
	Email string `json:"email"`
	// Username is the user's display/login name (empty in refresh tokens).
	Username string `json:"username"`
	// Roles lists the user's role names, checked by HasRole and friends
	// (nil in refresh tokens).
	Roles []string `json:"roles"`
	// RegisteredClaims carries exp, iat, nbf, iss, sub and jti.
	jwt.RegisteredClaims
}
|
|
|
|
// TokenPair is the result of GenerateTokenPair: a signed access token, a
// signed refresh token, and metadata describing the access token.
type TokenPair struct {
	// AccessToken is the signed JWT carrying the full user claims.
	AccessToken string `json:"accessToken"`
	// RefreshToken is the signed JWT that RefreshToken exchanges for a new pair.
	RefreshToken string `json:"refreshToken"`
	// ExpiresAt is the expiration time of AccessToken (not of RefreshToken).
	ExpiresAt time.Time `json:"expiresAt"`
	// TokenType is the authorization scheme; GenerateTokenPair always sets "Bearer".
	TokenType string `json:"tokenType"`
}
|
|
|
|
// Package-level state. The unexported settings are populated by Initialize
// and read when tokens are signed or verified; the exported values are the
// sentinel errors returned by this package.
var (
	// jwtSecret is the HMAC signing key, loaded from JWT_SECRET.
	jwtSecret []byte
	// accessTokenTTL is how long newly issued access tokens remain valid.
	accessTokenTTL time.Duration
	// refreshTokenTTL is how long newly issued refresh tokens remain valid.
	refreshTokenTTL time.Duration
	// issuer is written to the "iss" claim of every issued token.
	issuer string

	// ErrInvalidToken is returned when a token cannot be parsed or verified.
	ErrInvalidToken = errors.New("invalid token")
	// ErrExpiredToken is returned when a token's expiration time has passed.
	ErrExpiredToken = errors.New("token has expired")
	// ErrInvalidClaims is returned when a token's claims are missing or malformed.
	ErrInvalidClaims = errors.New("invalid token claims")
)

// Initialize loads the JWT configuration from environment variables:
//
//   - JWT_SECRET (required): HMAC signing key; an error is returned if it is
//     unset or empty, and no other setting is changed in that case.
//   - JWT_ACCESS_TTL_HOURS (default 24): access-token lifetime in hours.
//   - JWT_REFRESH_TTL_DAYS (default 7): refresh-token lifetime in days.
//   - JWT_ISSUER (default "omega-server"): value of the "iss" claim.
//
// A TTL that is unset, not an integer, zero, negative, or too large to fit in
// a time.Duration falls back to its default. Accepting zero/negative values,
// or letting the multiplication overflow, would silently mint tokens that are
// already expired at the moment they are issued.
func Initialize() error {
	secret := os.Getenv("JWT_SECRET")
	if secret == "" {
		return errors.New("JWT_SECRET environment variable is required")
	}
	jwtSecret = []byte(secret)

	accessTokenTTL = ttlFromEnv("JWT_ACCESS_TTL_HOURS", 24, time.Hour)
	refreshTokenTTL = ttlFromEnv("JWT_REFRESH_TTL_DAYS", 7, 24*time.Hour)

	issuer = os.Getenv("JWT_ISSUER")
	if issuer == "" {
		issuer = "omega-server"
	}

	return nil
}

// ttlFromEnv reads the environment variable key as a count of unit and returns
// count*unit. It returns def*unit when the variable is unset, malformed,
// non-positive, or so large that count*unit would overflow time.Duration.
func ttlFromEnv(key string, def int64, unit time.Duration) time.Duration {
	const maxDuration = time.Duration(1<<63 - 1)

	n, err := strconv.ParseInt(os.Getenv(key), 10, 64)
	if err != nil || n <= 0 || n > int64(maxDuration/unit) {
		n = def
	}
	return time.Duration(n) * unit
}
|
|
|
|
// GenerateTokenPair generates both access and refresh tokens
|
|
func GenerateTokenPair(userID, email, username string, roles []string) (*TokenPair, error) {
|
|
if len(jwtSecret) == 0 {
|
|
if err := Initialize(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
now := time.Now()
|
|
accessExpiresAt := now.Add(accessTokenTTL)
|
|
refreshExpiresAt := now.Add(refreshTokenTTL)
|
|
|
|
// Create access token claims
|
|
accessClaims := &Claims{
|
|
UserID: userID,
|
|
Email: email,
|
|
Username: username,
|
|
Roles: roles,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(accessExpiresAt),
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
NotBefore: jwt.NewNumericDate(now),
|
|
Issuer: issuer,
|
|
Subject: userID,
|
|
ID: generateJTI(),
|
|
},
|
|
}
|
|
|
|
// Create refresh token claims (minimal data)
|
|
refreshClaims := &Claims{
|
|
UserID: userID,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
ExpiresAt: jwt.NewNumericDate(refreshExpiresAt),
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
NotBefore: jwt.NewNumericDate(now),
|
|
Issuer: issuer,
|
|
Subject: userID,
|
|
ID: generateJTI(),
|
|
},
|
|
}
|
|
|
|
// Generate access token
|
|
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
|
|
accessTokenString, err := accessToken.SignedString(jwtSecret)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Generate refresh token
|
|
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
|
|
refreshTokenString, err := refreshToken.SignedString(jwtSecret)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &TokenPair{
|
|
AccessToken: accessTokenString,
|
|
RefreshToken: refreshTokenString,
|
|
ExpiresAt: accessExpiresAt,
|
|
TokenType: "Bearer",
|
|
}, nil
|
|
}
|
|
|
|
// ValidateToken validates a JWT token and returns the claims
|
|
func ValidateToken(tokenString string) (*Claims, error) {
|
|
if len(jwtSecret) == 0 {
|
|
if err := Initialize(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
|
// Validate signing method
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, errors.New("invalid signing method")
|
|
}
|
|
return jwtSecret, nil
|
|
})
|
|
|
|
if err != nil {
|
|
if errors.Is(err, jwt.ErrTokenExpired) {
|
|
return nil, ErrExpiredToken
|
|
}
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
claims, ok := token.Claims.(*Claims)
|
|
if !ok || !token.Valid {
|
|
return nil, ErrInvalidClaims
|
|
}
|
|
|
|
// Additional validation
|
|
if claims.UserID == "" {
|
|
return nil, ErrInvalidClaims
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// RefreshToken generates a new access token from a valid refresh token
|
|
func RefreshToken(refreshTokenString string) (*TokenPair, error) {
|
|
// Validate refresh token
|
|
claims, err := ValidateToken(refreshTokenString)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// For refresh tokens, we only have minimal user data
|
|
// In a real application, you might want to fetch fresh user data from the database
|
|
return GenerateTokenPair(claims.UserID, claims.Email, claims.Username, claims.Roles)
|
|
}
|
|
|
|
// ExtractTokenFromHeader returns the bearer token from an HTTP Authorization
// header value of the form "Bearer <token>".
//
// The "Bearer" scheme is matched case-insensitively, because HTTP
// authentication schemes are case-insensitive (RFC 7235 §2.1). Whitespace
// surrounding the token is trimmed.
//
// It returns an error if the header is empty, does not use the Bearer scheme,
// or carries no (non-blank) token.
func ExtractTokenFromHeader(authHeader string) (string, error) {
	if authHeader == "" {
		return "", errors.New("authorization header is required")
	}

	const bearerPrefix = "Bearer "
	if len(authHeader) < len(bearerPrefix) || !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
		return "", errors.New("invalid authorization header format")
	}

	// Trim so stray spaces (e.g. "Bearer  abc") do not end up in the token
	// and a whitespace-only credential is rejected.
	token := strings.TrimSpace(authHeader[len(bearerPrefix):])
	if token == "" {
		return "", errors.New("token is required")
	}

	return token, nil
}
|
|
|
|
// IsTokenExpired checks if a token is expired without validating signature
|
|
func IsTokenExpired(tokenString string) bool {
|
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
|
return jwtSecret, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return true
|
|
}
|
|
|
|
claims, ok := token.Claims.(*Claims)
|
|
if !ok {
|
|
return true
|
|
}
|
|
|
|
return claims.ExpiresAt.Before(time.Now())
|
|
}
|
|
|
|
// GetTokenClaims extracts claims from token without validating signature (use carefully)
|
|
func GetTokenClaims(tokenString string) (*Claims, error) {
|
|
token, _, err := new(jwt.Parser).ParseUnverified(tokenString, &Claims{})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
claims, ok := token.Claims.(*Claims)
|
|
if !ok {
|
|
return nil, ErrInvalidClaims
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// RevokeToken adds token to revocation list (implement with your chosen storage)
|
|
func RevokeToken(tokenString string) error {
|
|
// In a real implementation, you would store revoked tokens in a database or cache
|
|
// with their JTI (JWT ID) and expiration time for efficient cleanup
|
|
|
|
claims, err := GetTokenClaims(tokenString)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Store the JTI in your revocation storage
|
|
// Example: revokedTokens[claims.ID] = claims.ExpiresAt
|
|
|
|
_ = claims // Placeholder to avoid unused variable error
|
|
|
|
return nil
|
|
}
|
|
|
|
// IsTokenRevoked checks if a token has been revoked (implement with your chosen storage)
|
|
func IsTokenRevoked(tokenString string) bool {
|
|
// In a real implementation, check if the token's JTI is in the revocation list
|
|
|
|
claims, err := GetTokenClaims(tokenString)
|
|
if err != nil {
|
|
return true
|
|
}
|
|
|
|
// Check revocation storage
|
|
// Example: _, revoked := revokedTokens[claims.ID]
|
|
// return revoked
|
|
|
|
_ = claims // Placeholder to avoid unused variable error
|
|
|
|
return false
|
|
}
|
|
|
|
// generateJTI returns a new JWT ID ("jti" claim): 128 bits from crypto/rand,
// hex-encoded as 32 lowercase characters.
//
// Randomness, rather than a timestamp, keeps IDs unique even for back-to-back
// calls. GenerateTokenPair mints two IDs in immediate succession, and
// concurrent requests may do the same. Random IDs are also unguessable, which
// matters once JTIs are used as revocation keys.
func generateJTI() string {
	var buf [16]byte
	if _, err := rand.Read(buf[:]); err != nil {
		// crypto/rand failing means the OS entropy source is broken; fall back
		// to a timestamp so token issuance does not fail outright.
		return strconv.FormatInt(time.Now().UnixNano(), 36)
	}
	return hex.EncodeToString(buf[:])
}
|
|
|
|
// GetTokenExpirationTime returns the expiration time of a token
|
|
func GetTokenExpirationTime(tokenString string) (time.Time, error) {
|
|
claims, err := GetTokenClaims(tokenString)
|
|
if err != nil {
|
|
return time.Time{}, err
|
|
}
|
|
|
|
if claims.ExpiresAt == nil {
|
|
return time.Time{}, errors.New("token has no expiration time")
|
|
}
|
|
|
|
return claims.ExpiresAt.Time, nil
|
|
}
|
|
|
|
// GetTokenRemainingTime returns the remaining time before token expires
|
|
func GetTokenRemainingTime(tokenString string) (time.Duration, error) {
|
|
expirationTime, err := GetTokenExpirationTime(tokenString)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
remaining := time.Until(expirationTime)
|
|
if remaining < 0 {
|
|
return 0, ErrExpiredToken
|
|
}
|
|
|
|
return remaining, nil
|
|
}
|
|
|
|
// HasRole checks if the token claims contain a specific role
|
|
func (c *Claims) HasRole(role string) bool {
|
|
for _, r := range c.Roles {
|
|
if r == role {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// HasAnyRole checks if the token claims contain any of the specified roles
|
|
func (c *Claims) HasAnyRole(roles []string) bool {
|
|
for _, role := range roles {
|
|
if c.HasRole(role) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// HasAllRoles checks if the token claims contain all of the specified roles
|
|
func (c *Claims) HasAllRoles(roles []string) bool {
|
|
for _, role := range roles {
|
|
if !c.HasRole(role) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// IsAdmin checks if the user has admin role
|
|
func (c *Claims) IsAdmin() bool {
|
|
return c.HasRole("admin")
|
|
}
|
|
|
|
// GetUserInfo returns basic user information from claims
|
|
func (c *Claims) GetUserInfo() map[string]interface{} {
|
|
return map[string]interface{}{
|
|
"userId": c.UserID,
|
|
"email": c.Email,
|
|
"username": c.Username,
|
|
"roles": c.Roles,
|
|
}
|
|
}
|
|
|
|
// GenerateToken generates a JWT token for a user (backward compatibility)
|
|
func GenerateToken(userID, email, username string, roles []string) (string, error) {
|
|
tokenPair, err := GenerateTokenPair(userID, email, username, roles)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return tokenPair.AccessToken, nil
|
|
}
|