// Package jwttoken encrypts and decrypts tokens using JWT package jwttoken import ( "fmt" "os" "strings" "sync" "time" "github.com/golang-jwt/jwt/v5" "github.com/rs/zerolog" zlog "github.com/rs/zerolog/log" "quantex.com/qfixdpl/src/app" "quantex.com/qfixdpl/src/app/version" "quantex.com/qfixdpl/src/common/tracerr" ) const ( postFix = "_QUANTEX_SECRET_KEY" errAuthInvalidJWT = "invalid JWT token at %s" errAuthMissingTokenClaim = "missing token claim %s at %s" errAuthTokenNotSigned = "token could not be signed at Encrypt: %w" errAuthMissingToken = "missing authentication token for service at %s: %s" errAuthExpiredToken = "authentication token expired at %s" errAuthExpClaim = "expiration claim is missing or invalid at %s" errAuthSecretKeyNotFound = "secret key %s not found in environment" errAuthTokenNotParsable = "token could not be parsed at %s:%w " errAuthServiceNameEmpty = "service name cannot be empty at %s - received: %s" claimToken = "token" claimPermissions = "permissions" claimIss = "iss" claimExp = "exp" claimIat = "iat" ) var claimsToValidate = []string{claimToken, claimIss} //nolint:gochecknoglobals // Set claims to validate var log zerolog.Logger //nolint var secretCache sync.Map //nolint:gochecknoglobals // Cache secrets after first lookup func init() { log = zlog.With().Str("gtag", "secure_token").Logger().Level(zerolog.InfoLevel) } // Validate decrypts and validates a JWT token and its claims. // It returns the AuthorizedService associated with the token if valid. // An error is returned if the token is invalid or any validation step fails. // The service parameter specifies the expected issuer of the token. func Validate(service, token string, authServices map[string]app.AuthorizedService) (*app.AuthorizedService, error) { claims, err := decrypt(service, token) if err != nil { err := tracerr.Errorf("JWT Token could not be decrypted: %w", err) log.Error().Msg(err.Error()) return nil, err } ok, err := validateJWTClaims(claims) if err != nil || !ok { err := tracerr.Errorf("invalid claims: %w", err) log.Error().Msg(err.Error()) return nil, err } auth, err := serviceAuth(claims, authServices) if err != nil { err := tracerr.Errorf("service auth validation failed: %w", err) log.Error().Msg(err.Error()) return nil, err } t, ok := claims[claimToken].(string) if !ok { err := tracerr.Errorf("token claim is not a string") log.Error().Msg(err.Error()) return nil, err } auth.Token = &t return auth, nil } // Encrypt creates a JWT token string from the provided ExtAuth information. // It returns the signed JWT token string or an error if the process fails. // The auth parameter contains the necessary information for token creation. func Encrypt(auth app.ExtAuth) (string, error) { if auth.Name == "" { err := tracerr.Errorf("auth.Name cannot be empty at Encrypt") log.Error().Msg(err.Error()) return "", err } claims := jwt.MapClaims{ claimToken: auth.Token, claimIss: version.AppName, claimIat: time.Now().Unix(), } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) secret, err := findSecret(auth.Name) if secret == nil || len(secret) == 0 || err != nil { err := tracerr.Errorf(errAuthMissingToken, "Encrypt", auth.Name) log.Error().Msg(err.Error()) return "", err } signedTkn, err := token.SignedString(secret) if err != nil { err = tracerr.Errorf(errAuthTokenNotSigned, err) log.Error().Msg(err.Error()) return "", err } return signedTkn, nil } // IsJWT checks if a token string is a valid JWT format. // JWT tokens have exactly 3 parts separated by dots (header.payload.signature). func IsJWT(token string) bool { // Remove any whitespace token = strings.TrimSpace(token) // JWT has exactly 3 parts separated by dots parts := strings.Split(token, ".") if len(parts) != 3 { return false } // Each part should be non-empty (base64url encoded) for _, part := range parts { if part == "" { return false } } return true } // decrypt decrypts a JWT token string using the secret associated with the given service. // It returns the JWT claims if decryption is successful, or an error otherwise. // The service parameter specifies the expected issuer of the token. func decrypt(service, token string) (jwt.MapClaims, error) { if service == "" { err := tracerr.Errorf(errAuthServiceNameEmpty, "Decrypt", service) log.Error().Msg(err.Error()) return jwt.MapClaims{}, err } secret, err := findSecret(service) if secret == nil || len(secret) == 0 || err != nil { err := tracerr.Errorf(errAuthMissingToken, "Decrypt", service) log.Error().Msg(err.Error()) return jwt.MapClaims{}, err } tkn, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { if t.Method.Alg() != jwt.SigningMethodHS256.Alg() { return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) } return secret, nil }) if err != nil { err = tracerr.Errorf(errAuthTokenNotParsable, "Decrypt", err) log.Error().Msg(err.Error()) return jwt.MapClaims{}, err } claims, ok := tkn.Claims.(jwt.MapClaims) if !ok { err := tracerr.Errorf(errAuthInvalidJWT, "Decrypt") log.Error().Msg(err.Error()) return jwt.MapClaims{}, err } return claims, nil } // serviceAuth validates the service based on JWT claims and authorized services. // It returns the AuthorizedService if validation is successful, or an error otherwise. // The services parameter contains the map of authorized services to validate against. func serviceAuth(claims jwt.MapClaims, services map[string]app.AuthorizedService) (out *app.AuthorizedService, err error) { issuer, ok := claims[claimIss].(string) if !ok { err := tracerr.Errorf("issuer claim is not a string") log.Error().Msg(err.Error()) return nil, err } service, ok := services[issuer] if !ok { err = tracerr.Errorf("Unknown service attempted access - issuer: %s", issuer) log.Warn().Str("issuer", issuer).Msg(err.Error()) return nil, err } // TODO: change required permission as needed ok = service.HasPermissions(app.ServicePermissionFullAccess) if !ok { err = tracerr.Errorf("Service without required permissions attempted access") log.Warn().Str("issuer", issuer).Msg(err.Error()) return nil, err } return &service, nil } // findSecret retrieves the secret key for a given application name from environment variables. // It caches the secret after the first lookup for performance. // It returns the secret as a byte slice or an error if not found. func findSecret(appName string) ([]byte, error) { key := strings.ToUpper(appName) + postFix if k, ok := secretCache.Load(key); ok { if b, ok := k.([]byte); ok { return b, nil } } secret := os.Getenv(key) if secret == "" { err := tracerr.Errorf(errAuthSecretKeyNotFound, key) return nil, err } secretCache.Store(key, []byte(secret)) return []byte(secret), nil } // validateJWTClaims checks the presence and validity of required JWT claims. // It returns true if all claims are valid, or an error otherwise. func validateJWTClaims(claims jwt.MapClaims) (ok bool, err error) { for _, claim := range claimsToValidate { ok, err = validateClaims(claims, claim) if err != nil || !ok { err := tracerr.Errorf("error %s token at ValidateClaims: %w", claim, err) log.Error().Msg(err.Error()) return false, err } } return true, nil } // validateClaims validates a specific JWT claim based on its type. func validateClaims(claims jwt.MapClaims, claim string) (ok bool, err error) { switch claim { case claimExp: return validateExpiration(claims) case claimToken: return validateToken(claims) case claimIss: return validateIssuer(claims) default: err := fmt.Errorf("unknown claim %s at validateTokenClaims", claim) log.Error().Msg(err.Error()) return false, err } } // TODO: when needed add claimExp to claimsToValidate var to implement. func validateExpiration(claims jwt.MapClaims) (ok bool, err error) { exp, ok := claims[claimExp].(float64) if !ok { err := fmt.Errorf(errAuthExpClaim, "validateExpiration") log.Error().Msg(err.Error()) return false, nil } if int64(exp) < time.Now().Unix() { err := fmt.Errorf(errAuthExpiredToken, "validateExpiration") log.Error().Msg(err.Error()) return false, err } return true, nil } // validateToken checks for the presence of the token claim in the JWT claims. func validateToken(claims jwt.MapClaims) (ok bool, err error) { _, ok = claims[claimToken].(string) if !ok { err = tracerr.Errorf(errAuthMissingTokenClaim, claimToken, "validateToken") log.Error().Msg(err.Error()) return false, err } return true, nil } // validateIssuer checks for the presence of the issuer claim in the JWT claims. func validateIssuer(claims jwt.MapClaims) (ok bool, err error) { _, ok = claims[claimIss].(string) if !ok { err = tracerr.Errorf(errAuthMissingTokenClaim, claimIss, "validateIssuer") log.Error().Msg(err.Error()) return false, err } return true, nil }