Files
cloud-services/pkg/jwt/jwt_validator.go

102 lines
3.4 KiB
Go

package jwt
import (
"context"
"encoding/json"
"time"
"fiskerinc.com/modules/logger"
"fiskerinc.com/modules/utils/envtool"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
"github.com/pkg/errors"
)
// Header name and error messages used by the validator.
const (
	// AuthenticationHeader names the "Authorization" header. It is not
	// referenced in this file; presumably callers use it to locate the token.
	AuthenticationHeader = "Authorization"

	// invalidTokenError is the generic message ValidateToken returns when the
	// token cannot be parsed or verified.
	invalidTokenError = "invalid token"
	// expiredTokenError is the message ValidateToken returns when the
	// expiration check is enabled and the current time is after the token's
	// expiration.
	expiredTokenError = "token expired"
)
// DefaultJWK is the JWKS JSON document used by getKeys when no JWK URL is
// configured, i.e. neither the URL passed to NewJWTValidator nor the JWK_URL
// environment variable is set. It is a copy of the key set published at
// https://cognito-idp.us-west-2.amazonaws.com/us-west-2_AWwjLXym2/.well-known/jwks.json
//
// NOTE(review): as a static snapshot it will not pick up signing-key rotation
// in that user pool; confirm it is kept in sync.
var DefaultJWK = `{"keys":[{"alg":"RS256","e":"AQAB","kid":"eQ3ndRKiE/O8UFyDqlb3tKTsXm4K5O2W85wwUi3OkMg=","kty":"RSA","n":"pVeGfTzlCvcnzUE4f7LsVDhzsZbGdAn6q1LH3DSwqFF6Xw-c6z8AGV744_qvxRrDlmQs85cXPJHh2AVKJQnWBipp6EUWO5TEdMS_0cgoTk1Gr3CagUnYBZwm53HIUC8bMuWx0C6FQWcnmleNQbWR_k-zipsPbZw2sYAtSWRVGfjG6Gwo4wZx0spBk9hq3ovG5mVxnItnKJYWyx3V_ZKKa5r5ImItJa1AwaxoZxsO13NMOPTed89iSbK_IR_Db8pX6STgl6pa6YYSvI1-phBt_PLjTz2gusRj897sHxJYga5KfNgbvNkeHdaDljwilT4IKDZq1hzIrmaPrUKApb0e9w","use":"sig"},{"alg":"RS256","e":"AQAB","kid":"jIz0QTcsKCT+hxGz2S0+ChPyN7w8riP/l6mqzAXRl6o=","kty":"RSA","n":"yDqFnw52wraJImOT5rCPL2www0pRglnSS-GPG6kZMqos7KHqcO5pVD020_5g2OefK6Gs0ndUI3eDOeBwASKeZuoezAgu9D9whFHJI6-_oIiz2af3ahodRISnhFAbwcvU4i8_M6OWATVaTU5aODAcM_8q1aS-Rfp6zY9rrlaJ6RmCdYeVNue4nvS97bOrpTXmFBB2fAzbhWSq0axmWZWBFyMO12FFMvT_dCaL1dzBOEzNQU03tKsUa0WEqNs169utuo9TydX9hhjpnDtqYjIEvyOFTAnU8IldX_iiWbnR1-8BHeyqomMQFIjQCTRkLReKYDAyrVF4cFah-BDYQiluCw","use":"sig"}]}`
// JWTValidatorInterface abstracts JWT validation so callers can depend on it
// rather than on the concrete *JWTValidator that NewJWTValidator returns.
type JWTValidatorInterface interface {
	// ValidateToken verifies token and, on success, returns its private
	// (non-registered) claims. Failures are reported with short, generic
	// error messages.
	ValidateToken(token string) (map[string]interface{}, error)
	// ValidateError returns the detailed parse/verification error for token,
	// or nil if it verifies.
	ValidateError(token string) (err error)
	// SetKeys replaces the verification key set with one decoded from a JWKS
	// JSON document.
	SetKeys(data string)
	// DisableExpireCheck turns the expiration check off (true) or on (false).
	// Only use it in unit tests.
	DisableExpireCheck(disable bool)
}
// NewJWTValidator returns a validator whose key set is loaded lazily, on the
// first ValidateToken/ValidateError call. The keys come from the first
// available of:
//   - jwkUrl;
//   - the JWK_URL environment variable, when jwkUrl is empty;
//   - the embedded DefaultJWK, when neither is set.
func NewJWTValidator(jwkUrl string) JWTValidatorInterface {
	validator := new(JWTValidator)
	validator.jwkUrl = jwkUrl
	return validator
}
// JWTValidator verifies JWTs against a JWK set. The set is loaded lazily by
// getKeys and cached, together with the jwt.WithKeySet parse option built
// from it.
//
// NOTE(review): the lazy initialisation of keys/option is a plain
// check-then-assign with no lock. Confirm instances are not shared across
// goroutines before their first validation, or add synchronisation.
type JWTValidator struct {
	// disableExpireCheck skips the expiration check in ValidateToken (tests only).
	disableExpireCheck bool
	// jwkUrl is the JWKS endpoint. If empty, getKeys reads JWK_URL from the
	// environment; the value "NOT_SET" selects DefaultJWK.
	jwkUrl string
	// keys is the cached key set; nil until loaded (or after a failed fetch).
	keys jwk.Set
	// option is the cached jwt.WithKeySet(keys) parse option; nil until built.
	option jwt.ParseOption
}
// ValidateToken parses token and verifies its signature against the
// validator's key set. When the key does not specify an algorithm, the
// algorithm is inferred from the key. Unless DisableExpireCheck(true) was
// called, the token is then rejected if the current time is after its
// expiration. On success it returns the token's private claims.
//
// Parse/verification failures are logged and reported only as the generic
// "invalid token" error; use ValidateError to get the underlying cause.
// NOTE(review): a token with no exp claim presumably yields a zero
// Expiration() in jwx, so it is reported as "token expired" while the check
// is enabled.
func (v *JWTValidator) ValidateToken(token string) (map[string]interface{}, error) {
	options := []jwt.ParseOption{v.getKeySetOption(), jwt.InferAlgorithmFromKey(true)}
	parsed, parseErr := jwt.ParseString(token, options...)
	if parseErr != nil {
		logger.Info().Err(parseErr).Send()
		return nil, errors.New(invalidTokenError)
	}

	expired := !v.disableExpireCheck && time.Now().After(parsed.Expiration())
	if expired {
		return nil, errors.New(expiredTokenError)
	}
	return parsed.PrivateClaims(), nil
}
// ValidateError parses and verifies token and returns the raw error from the
// jwt library, or nil on success. The error carries the technical details
// that ValidateToken replaces with a generic "invalid token" message.
//
// It uses the same parse options as ValidateToken, including
// InferAlgorithmFromKey. This way it reports the error ValidateToken actually
// encountered, not a different one caused by stricter algorithm matching.
// It does not perform ValidateToken's expiration check.
func (v *JWTValidator) ValidateError(token string) (err error) {
	_, err = jwt.ParseString(token, v.getKeySetOption(), jwt.InferAlgorithmFromKey(true))
	return err
}
// getKeySetOption returns the jwt.WithKeySet parse option for the validator's
// keys. On first use it builds the option from getKeys.
//
// The option is cached only once a key set is available. getKeys leaves
// v.keys nil after a failed fetch so that it can retry. Caching an option
// that wraps a nil set would defeat that retry and make every later
// validation fail.
func (v *JWTValidator) getKeySetOption() jwt.ParseOption {
	if v.option != nil {
		return v.option
	}

	keys := v.getKeys()
	if keys == nil {
		// Key loading failed. Return an uncached option (parsing will fail for
		// this call) so the next call retries getKeys.
		// NOTE: during a JWKS outage this means every call retries the fetch.
		return jwt.WithKeySet(keys)
	}

	v.option = jwt.WithKeySet(keys)
	logger.Info().Msgf("getKeySetOption %v", v.option)
	return v.option
}
// getKeys returns the validator's JWK set and loads it on first use:
//   - if jwkUrl is empty, it is taken from the JWK_URL environment variable
//     (default "NOT_SET");
//   - "NOT_SET" selects the embedded DefaultJWK;
//   - any other value is fetched over HTTP with jwk.Fetch.
//
// A failed fetch is logged and returns nil without caching anything, so a
// later call retries.
func (v *JWTValidator) getKeys() jwk.Set {
	if v.keys != nil {
		return v.keys
	}

	if len(v.jwkUrl) == 0 {
		v.jwkUrl = envtool.GetEnv("JWK_URL", "NOT_SET")
	}
	if v.jwkUrl == "NOT_SET" {
		logger.Info().Msg("getKeys no jwk url, using default")
		v.SetKeys(DefaultJWK)
		return v.keys
	}

	// Bound the fetch. Without a deadline on the context, an unresponsive JWKS
	// endpoint would block the validating caller indefinitely.
	const fetchTimeout = 10 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), fetchTimeout)
	defer cancel()

	keys, err := jwk.Fetch(ctx, v.jwkUrl)
	if err != nil {
		logger.Error().Err(errors.Wrapf(err, "fetch JWK set from %s", v.jwkUrl)).Send()
		return nil
	}

	v.keys = keys
	logger.Info().Msgf("getKeys %v", v.keys)
	return v.keys
}
// DisableExpireCheck turns ValidateToken's expiration check off (true) or
// back on (false). Only use this in unit tests: while disabled, expired
// tokens are accepted.
func (v *JWTValidator) DisableExpireCheck(disable bool) {
	v.disableExpireCheck = disable
}
// SetKeys replaces the validator's key set with one decoded from data, a JWKS
// JSON document such as DefaultJWK.
//
// Any cached parse option is discarded so the new keys take effect on the
// next validation. Otherwise a validator that had already validated a token
// would keep verifying against its previous key set. The interface gives
// SetKeys no error return, so a decode error is logged. In that case the key
// set may be left empty or partially populated.
func (v *JWTValidator) SetKeys(data string) {
	v.keys = jwk.NewSet()
	// Decoding goes through the interface into the concrete set created above
	// (encoding/json follows a non-nil pointer held in an interface).
	if err := json.Unmarshal([]byte(data), &v.keys); err != nil {
		logger.Error().Err(err).Msg("SetKeys failed to decode JWK set")
	}
	v.option = nil
}