348 lines
9.0 KiB
Go
348 lines
9.0 KiB
Go
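// Package jwttoken creates and validates HS256-signed JWT tokens. Despite the
// Encrypt/decrypt naming, tokens are signed rather than encrypted: their claims are
// only base64url-encoded and must not carry confidential data.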
|
|
package jwttoken
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/rs/zerolog"
|
|
zlog "github.com/rs/zerolog/log"
|
|
|
|
"quantex.com/qfixdpl/src/app"
|
|
"quantex.com/qfixdpl/src/app/version"
|
|
"quantex.com/qfixdpl/src/common/tracerr"
|
|
)
|
|
|
|
const (
	// postFix is appended to the upper-cased service name to form the name of the
	// environment variable holding that service's HMAC secret
	// (e.g. "orders" -> "ORDERS_QUANTEX_SECRET_KEY"; see findSecret).
	postFix = "_QUANTEX_SECRET_KEY"

	// Error message formats passed to tracerr.Errorf / fmt.Errorf in this package.
	errAuthInvalidJWT        = "invalid JWT token at %s"
	errAuthMissingTokenClaim = "missing token claim %s at %s"
	errAuthTokenNotSigned    = "token could not be signed at Encrypt: %w"
	errAuthMissingToken      = "missing authentication token for service at %s: %s"
	errAuthExpiredToken      = "authentication token expired at %s"
	errAuthExpClaim          = "expiration claim is missing or invalid at %s"
	errAuthSecretKeyNotFound = "secret key %s not found in environment"
	errAuthTokenNotParsable  = "token could not be parsed at %s: %w"
	errAuthServiceNameEmpty  = "service name cannot be empty at %s - received: %s"

	// Claim names used in the JWT payload. "iss", "exp" and "iat" are RFC 7519
	// registered claims; "token" and "permissions" are private to this system.
	claimToken       = "token"
	claimPermissions = "permissions"
	claimIss         = "iss"
	claimExp         = "exp"
	claimIat         = "iat"
)
|
|
|
|
var claimsToValidate = []string{claimToken, claimIss} //nolint:gochecknoglobals // Set claims to validate
|
|
|
|
var log zerolog.Logger //nolint
|
|
|
|
var secretCache sync.Map //nolint:gochecknoglobals // Cache secrets after first lookup
|
|
|
|
func init() {
|
|
log = zlog.With().Str("gtag", "secure_token").Logger().Level(zerolog.InfoLevel)
|
|
}
|
|
|
|
// Validate decrypts and validates a JWT token and its claims.
|
|
// It returns the AuthorizedService associated with the token if valid.
|
|
// An error is returned if the token is invalid or any validation step fails.
|
|
// The service parameter specifies the expected issuer of the token.
|
|
func Validate(service, token string, authServices map[string]app.AuthorizedService) (*app.AuthorizedService, error) {
|
|
claims, err := decrypt(service, token)
|
|
if err != nil {
|
|
err := tracerr.Errorf("JWT Token could not be decrypted: %w", err)
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return nil, err
|
|
}
|
|
|
|
ok, err := validateJWTClaims(claims)
|
|
if err != nil || !ok {
|
|
err := tracerr.Errorf("invalid claims: %w", err)
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return nil, err
|
|
}
|
|
|
|
auth, err := serviceAuth(claims, authServices)
|
|
if err != nil {
|
|
err := tracerr.Errorf("service auth validation failed: %w", err)
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return nil, err
|
|
}
|
|
|
|
t, ok := claims[claimToken].(string)
|
|
if !ok {
|
|
err := tracerr.Errorf("token claim is not a string")
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return nil, err
|
|
}
|
|
|
|
auth.Token = &t
|
|
|
|
return auth, nil
|
|
}
|
|
|
|
// Encrypt creates a JWT token string from the provided ExtAuth information.
|
|
// It returns the signed JWT token string or an error if the process fails.
|
|
// The auth parameter contains the necessary information for token creation.
|
|
func Encrypt(auth app.ExtAuth) (string, error) {
|
|
if auth.Name == "" {
|
|
err := tracerr.Errorf("auth.Name cannot be empty at Encrypt")
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return "", err
|
|
}
|
|
|
|
claims := jwt.MapClaims{
|
|
claimToken: auth.Token,
|
|
claimIss: version.AppName,
|
|
claimIat: time.Now().Unix(),
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
|
|
secret, err := findSecret(auth.Name)
|
|
if secret == nil || len(secret) == 0 || err != nil {
|
|
err := tracerr.Errorf(errAuthMissingToken, "Encrypt", auth.Name)
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return "", err
|
|
}
|
|
|
|
signedTkn, err := token.SignedString(secret)
|
|
if err != nil {
|
|
err = tracerr.Errorf(errAuthTokenNotSigned, err)
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return "", err
|
|
}
|
|
|
|
return signedTkn, nil
|
|
}
|
|
|
|
// IsJWT reports whether token, after trimming surrounding whitespace, has the shape of
// a compact-serialized JWT: exactly three non-empty segments separated by dots
// (header.payload.signature). Only the shape is checked; segments are not
// base64-decoded and the signature is not verified.
func IsJWT(token string) bool {
	trimmed := strings.TrimSpace(token)

	header, rest, found := strings.Cut(trimmed, ".")
	if !found {
		return false
	}

	payload, signature, found := strings.Cut(rest, ".")
	if !found || strings.Contains(signature, ".") {
		// Fewer than two dots, or more than two.
		return false
	}

	return header != "" && payload != "" && signature != ""
}
|
|
|
|
// decrypt decrypts a JWT token string using the secret associated with the given service.
|
|
// It returns the JWT claims if decryption is successful, or an error otherwise.
|
|
// The service parameter specifies the expected issuer of the token.
|
|
func decrypt(service, token string) (jwt.MapClaims, error) {
|
|
if service == "" {
|
|
err := tracerr.Errorf(errAuthServiceNameEmpty, "Decrypt", service)
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return jwt.MapClaims{}, err
|
|
}
|
|
|
|
secret, err := findSecret(service)
|
|
if secret == nil || len(secret) == 0 || err != nil {
|
|
err := tracerr.Errorf(errAuthMissingToken, "Decrypt", service)
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return jwt.MapClaims{}, err
|
|
}
|
|
|
|
tkn, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
|
|
if t.Method.Alg() != jwt.SigningMethodHS256.Alg() {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
|
}
|
|
|
|
return secret, nil
|
|
})
|
|
if err != nil {
|
|
err = tracerr.Errorf(errAuthTokenNotParsable, "Decrypt", err)
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return jwt.MapClaims{}, err
|
|
}
|
|
|
|
claims, ok := tkn.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
err := tracerr.Errorf(errAuthInvalidJWT, "Decrypt")
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return jwt.MapClaims{}, err
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// serviceAuth validates the service based on JWT claims and authorized services.
|
|
// It returns the AuthorizedService if validation is successful, or an error otherwise.
|
|
// The services parameter contains the map of authorized services to validate against.
|
|
func serviceAuth(claims jwt.MapClaims, services map[string]app.AuthorizedService) (out *app.AuthorizedService, err error) {
|
|
issuer, ok := claims[claimIss].(string)
|
|
if !ok {
|
|
err := tracerr.Errorf("issuer claim is not a string")
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return nil, err
|
|
}
|
|
|
|
service, ok := services[issuer]
|
|
if !ok {
|
|
err = tracerr.Errorf("Unknown service attempted access - issuer: %s", issuer)
|
|
|
|
log.Warn().Str("issuer", issuer).Msg(err.Error())
|
|
|
|
return nil, err
|
|
}
|
|
|
|
// TODO: change required permission as needed
|
|
ok = service.HasPermissions(app.ServicePermissionFullAccess)
|
|
if !ok {
|
|
err = tracerr.Errorf("Service without required permissions attempted access")
|
|
|
|
log.Warn().Str("issuer", issuer).Msg(err.Error())
|
|
|
|
return nil, err
|
|
}
|
|
|
|
return &service, nil
|
|
}
|
|
|
|
// findSecret retrieves the secret key for a given application name from environment variables.
|
|
// It caches the secret after the first lookup for performance.
|
|
// It returns the secret as a byte slice or an error if not found.
|
|
func findSecret(appName string) ([]byte, error) {
|
|
key := strings.ToUpper(appName) + postFix
|
|
|
|
if k, ok := secretCache.Load(key); ok {
|
|
if b, ok := k.([]byte); ok {
|
|
return b, nil
|
|
}
|
|
}
|
|
|
|
secret := os.Getenv(key)
|
|
if secret == "" {
|
|
err := tracerr.Errorf(errAuthSecretKeyNotFound, key)
|
|
|
|
return nil, err
|
|
}
|
|
|
|
secretCache.Store(key, []byte(secret))
|
|
|
|
return []byte(secret), nil
|
|
}
|
|
|
|
// validateJWTClaims checks the presence and validity of required JWT claims.
|
|
// It returns true if all claims are valid, or an error otherwise.
|
|
func validateJWTClaims(claims jwt.MapClaims) (ok bool, err error) {
|
|
for _, claim := range claimsToValidate {
|
|
ok, err = validateClaims(claims, claim)
|
|
if err != nil || !ok {
|
|
err := tracerr.Errorf("error %s token at ValidateClaims: %w", claim, err)
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return false, err
|
|
}
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// validateClaims validates a specific JWT claim based on its type.
|
|
func validateClaims(claims jwt.MapClaims, claim string) (ok bool, err error) {
|
|
switch claim {
|
|
case claimExp:
|
|
return validateExpiration(claims)
|
|
case claimToken:
|
|
return validateToken(claims)
|
|
case claimIss:
|
|
return validateIssuer(claims)
|
|
default:
|
|
err := fmt.Errorf("unknown claim %s at validateTokenClaims", claim)
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return false, err
|
|
}
|
|
}
|
|
|
|
// TODO: when needed add claimExp to claimsToValidate var to implement.
|
|
func validateExpiration(claims jwt.MapClaims) (ok bool, err error) {
|
|
exp, ok := claims[claimExp].(float64)
|
|
if !ok {
|
|
err := fmt.Errorf(errAuthExpClaim, "validateExpiration")
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return false, nil
|
|
}
|
|
|
|
if int64(exp) < time.Now().Unix() {
|
|
err := fmt.Errorf(errAuthExpiredToken, "validateExpiration")
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return false, err
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// validateToken checks for the presence of the token claim in the JWT claims.
|
|
func validateToken(claims jwt.MapClaims) (ok bool, err error) {
|
|
_, ok = claims[claimToken].(string)
|
|
if !ok {
|
|
err = tracerr.Errorf(errAuthMissingTokenClaim, claimToken, "validateToken")
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return false, err
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// validateIssuer checks for the presence of the issuer claim in the JWT claims.
|
|
func validateIssuer(claims jwt.MapClaims) (ok bool, err error) {
|
|
_, ok = claims[claimIss].(string)
|
|
if !ok {
|
|
err = tracerr.Errorf(errAuthMissingTokenClaim, claimIss, "validateIssuer")
|
|
|
|
log.Error().Msg(err.Error())
|
|
|
|
return false, err
|
|
}
|
|
|
|
return true, nil
|
|
}
|