package jwt import ( "context" "encoding/json" "time" "github.com/fiskerinc/cloud-services/pkg/logger" "github.com/fiskerinc/cloud-services/pkg/utils/envtool" "github.com/lestrrat-go/jwx/jwk" "github.com/lestrrat-go/jwx/jwt" "github.com/pkg/errors" ) const AuthenticationHeader = "Authorization" const invalidTokenError = "invalid token" const expiredTokenError = "token expired" // default jwk for https://cognito-idp.us-west-2.amazonaws.com/us-west-2_AWwjLXym2/.well-known/jwks.json var DefaultJWK = `{"keys":[{"alg":"RS256","e":"AQAB","kid":"eQ3ndRKiE/O8UFyDqlb3tKTsXm4K5O2W85wwUi3OkMg=","kty":"RSA","n":"pVeGfTzlCvcnzUE4f7LsVDhzsZbGdAn6q1LH3DSwqFF6Xw-c6z8AGV744_qvxRrDlmQs85cXPJHh2AVKJQnWBipp6EUWO5TEdMS_0cgoTk1Gr3CagUnYBZwm53HIUC8bMuWx0C6FQWcnmleNQbWR_k-zipsPbZw2sYAtSWRVGfjG6Gwo4wZx0spBk9hq3ovG5mVxnItnKJYWyx3V_ZKKa5r5ImItJa1AwaxoZxsO13NMOPTed89iSbK_IR_Db8pX6STgl6pa6YYSvI1-phBt_PLjTz2gusRj897sHxJYga5KfNgbvNkeHdaDljwilT4IKDZq1hzIrmaPrUKApb0e9w","use":"sig"},{"alg":"RS256","e":"AQAB","kid":"jIz0QTcsKCT+hxGz2S0+ChPyN7w8riP/l6mqzAXRl6o=","kty":"RSA","n":"yDqFnw52wraJImOT5rCPL2www0pRglnSS-GPG6kZMqos7KHqcO5pVD020_5g2OefK6Gs0ndUI3eDOeBwASKeZuoezAgu9D9whFHJI6-_oIiz2af3ahodRISnhFAbwcvU4i8_M6OWATVaTU5aODAcM_8q1aS-Rfp6zY9rrlaJ6RmCdYeVNue4nvS97bOrpTXmFBB2fAzbhWSq0axmWZWBFyMO12FFMvT_dCaL1dzBOEzNQU03tKsUa0WEqNs169utuo9TydX9hhjpnDtqYjIEvyOFTAnU8IldX_iiWbnR1-8BHeyqomMQFIjQCTRkLReKYDAyrVF4cFah-BDYQiluCw","use":"sig"}]}` type JWTValidatorInterface interface { ValidateToken(token string) (map[string]interface{}, error) ValidateError(token string) (err error) SetKeys(data string) DisableExpireCheck(disable bool) } func NewJWTValidator(jwkUrl string) JWTValidatorInterface { return &JWTValidator{jwkUrl: jwkUrl} } type JWTValidator struct { disableExpireCheck bool jwkUrl string keys jwk.Set option jwt.ParseOption } // ValidateToken validates a token func (v *JWTValidator) ValidateToken(token string) (map[string]interface{}, error) { result, err := jwt.ParseString(token, v.getKeySetOption(), jwt.InferAlgorithmFromKey(true)) if err != nil { logger.Info().Err(err).Send() return nil, errors.New(invalidTokenError) } if !v.disableExpireCheck && time.Now().After(result.Expiration()) { return nil, errors.New(expiredTokenError) } return result.PrivateClaims(), nil } // returns the original validation error with all the tech details func (v *JWTValidator) ValidateError(token string) (err error) { _, err = jwt.ParseString(token, v.getKeySetOption()) return } func (v *JWTValidator) getKeySetOption() jwt.ParseOption { if v.option == nil { v.option = jwt.WithKeySet(v.getKeys()) logger.Info().Msgf("getKeySetOption %v", v.option) } return v.option } func (v *JWTValidator) getKeys() jwk.Set { if v.keys == nil { var err error if len(v.jwkUrl) == 0 { v.jwkUrl = envtool.GetEnv("JWK_URL", "NOT_SET") } if v.jwkUrl == "NOT_SET" { logger.Info().Msg("getKeys no jwk url, using default") v.SetKeys(DefaultJWK) } else { v.keys, err = jwk.Fetch(context.Background(), v.jwkUrl) if err != nil { logger.Error().Err(errors.WithStack(err)).Send() } logger.Info().Msgf("getKeys %v", v.keys) } } return v.keys } // Only use this for unit tests to disable expire check on token func (v *JWTValidator) DisableExpireCheck(disable bool) { v.disableExpireCheck = disable } // SetKeys sets JWK keys from JSON data func (v *JWTValidator) SetKeys(data string) { v.keys = jwk.NewSet() json.Unmarshal([]byte(data), &v.keys) }