145 lines
3.9 KiB
Go
145 lines
3.9 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"fiskerinc.com/modules/common"
|
|
"fiskerinc.com/modules/db/queries"
|
|
|
|
"fiskerinc.com/modules/httpclient"
|
|
"fiskerinc.com/modules/jwt"
|
|
"fiskerinc.com/modules/logger"
|
|
"fiskerinc.com/modules/utils/envtool"
|
|
"fiskerinc.com/modules/utils"
|
|
)
|
|
|
|
// getUserURL is the auth-service endpoint that AppendUserMiddleware calls to
// resolve the caller's identity. It is evaluated once, during package
// initialization, from the AUTH_GET_USER environment variable.
// NOTE(review): envtool.GetEnv presumably falls back to the second argument
// (a dev host) when the variable is unset — confirm that production
// deployments always set AUTH_GET_USER.
var getUserURL = envtool.GetEnv("AUTH_GET_USER", "https://dev-auth.fiskerdps.com/auth/me")
|
|
|
|
// AppendUserMiddleware wraps next so that every request is identified by the
// auth service before it reaches the handler.
//
// For each request the wrapper:
//  1. answers 401 if jwt.GetAuthorizationHeader cannot read an Authorization
//     header from the request;
//  2. forwards that Authorization header in a GET to getUserURL and answers
//     401 if the call fails, returns a non-200 status, or its body cannot be
//     decoded by extractUserAttributes (500 if the request to the auth
//     service cannot even be built, e.g. a malformed getUserURL);
//  3. records the access via logAPICall using apiCalls — a failure there is
//     only logged and does not block the request;
//  4. stores the decoded user map in the request context under the
//     "identity" key and invokes next with that context.
func AppendUserMiddleware(next http.HandlerFunc, apiCalls queries.APICallsInterface) http.HandlerFunc {
	wrapper := func(w http.ResponseWriter, r *http.Request) {
		// Only the presence/shape of the header is checked here; the token
		// itself is judged by the auth service below.
		if _, err := jwt.GetAuthorizationHeader(r); err != nil {
			logger.Warn().Err(err).Msgf("token invalid %s %s", r.Method, r.RequestURI)
			utils.RespError(w, http.StatusUnauthorized, err.Error())
			return
		}

		// Go to auth to get user information. Binding the outbound request to
		// r.Context() lets a client disconnect cancel the upstream call.
		req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, getUserURL, nil)
		if err != nil {
			// Only reachable on a server-side misconfiguration (e.g. an invalid
			// getUserURL). Ignoring it would leave req nil and panic below.
			// The error text is not echoed because it may contain the URL.
			logger.Warn().Err(err).Msgf("Unable to build auth request %s %s", r.Method, r.RequestURI)
			utils.RespError(w, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
			return
		}
		req.Header.Set("Authorization", r.Header.Get("Authorization"))

		resp, err := httpclient.Do(req)
		if err != nil {
			logger.Warn().Err(err).Msgf("Unable to fetch user %s %s", r.Method, r.RequestURI)
			utils.RespError(w, http.StatusUnauthorized, err.Error())
			return
		}
		defer resp.Body.Close()

		// err is nil at this point (assuming httpclient.Do follows net/http's
		// "non-nil resp when err is nil" contract), so the upstream status is
		// the only detail worth reporting.
		if resp.StatusCode != http.StatusOK {
			logger.Warn().Msgf("Unable to fetch user %s %s: %s", r.Method, r.RequestURI, resp.Status)
			utils.RespError(w, http.StatusUnauthorized, resp.Status)
			return
		}

		user, err := extractUserAttributes(resp)
		if err != nil {
			logger.Warn().Err(err).Msgf("Unable to parse user response %s %s", r.Method, r.RequestURI)
			utils.RespError(w, http.StatusUnauthorized, err.Error())
			return
		}

		// Auditing is best-effort: a failed insert must not reject the request.
		if err := logAPICall(user, r.RequestURI, r.Method, apiCalls); err != nil {
			logger.Warn().Msgf("Call log %s %s '%v'", r.Method, r.RequestURI, err)
		}

		// NOTE(review): the plain string key "identity" is kept because other
		// code presumably reads ctx.Value("identity"); staticcheck SA1029
		// recommends an unexported key type if all readers can be migrated.
		ctx := context.WithValue(r.Context(), "identity", user)
		next.ServeHTTP(w, r.WithContext(ctx))
	}

	return wrapper
}
|
|
|
|
// AppendUserTokenMiddleware wraps next so that the request's token is
// validated locally rather than through the auth service.
//
// The request is rejected with 401 (and a "token invalid" warning is logged)
// when any of these fails:
//   - jwt.GetAuthorizationHeader cannot read the header,
//   - a validator built by jwt.NewJWTValidator("") rejects the token, or
//   - jwt.GetPayload cannot extract the token payload.
//
// Otherwise the access is recorded via logAPICall with apiCalls (a failure is
// only logged). The payload is then stored in the request context under the
// "identity" key and next is invoked.
func AppendUserTokenMiddleware(next http.HandlerFunc, apiCalls queries.APICallsInterface) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// reject logs the failure and answers 401 with the error's text; every
		// token problem below is reported the same way.
		reject := func(cause error) {
			logger.Warn().Err(cause).Msgf("token invalid %s %s", r.Method, r.RequestURI)
			utils.RespError(w, http.StatusUnauthorized, cause.Error())
		}

		token, err := jwt.GetAuthorizationHeader(r)
		if err != nil {
			reject(err)
			return
		}

		// A validator is constructed per request, exactly as before.
		validator := jwt.NewJWTValidator("")
		if _, err = validator.ValidateToken(token.Token); err != nil {
			reject(err)
			return
		}

		payload, err := jwt.GetPayload(token.Token)
		if err != nil {
			reject(err)
			return
		}

		// Auditing is best-effort: a failed insert is logged, not fatal.
		if logErr := logAPICall(payload, r.RequestURI, r.Method, apiCalls); logErr != nil {
			logger.Warn().Msgf("Call log %s %s '%v'", r.Method, r.RequestURI, logErr)
		}

		identified := context.WithValue(r.Context(), "identity", payload)
		next.ServeHTTP(w, r.WithContext(identified))
	}
}
|
|
|
|
func extractUserAttributes(resp *http.Response) (map[string]interface{}, error) {
|
|
user := make(map[string]interface{})
|
|
err := json.NewDecoder(resp.Body).Decode(&user)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
username, ok := user["id"]
|
|
if !ok {
|
|
return nil, errors.New("invalid user token")
|
|
}
|
|
|
|
user["username"] = username
|
|
return user, nil
|
|
}
|
|
|
|
func logAPICall(payload map[string]interface{}, uri string, method string, apiCalls queries.APICallsInterface) error {
|
|
var (
|
|
username string
|
|
ok bool
|
|
)
|
|
|
|
if username, ok = payload["username"].(string); !ok || username == "" {
|
|
return nil
|
|
}
|
|
|
|
endpoint, _, _ := strings.Cut(uri, "?")
|
|
_, err := apiCalls.Insert(common.APICall{
|
|
ClientID: username,
|
|
AccessType: common.AccessTypeJWT,
|
|
Endpoint: endpoint,
|
|
Method: method,
|
|
})
|
|
|
|
return err
|
|
}
|