Initial cloud-services repo - gateway service + pkg modules
This commit is contained in:
151
pkg/httphandlers/auth_apitoken.go
Normal file
151
pkg/httphandlers/auth_apitoken.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package httphandlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"fiskerinc.com/modules/adminroles"
|
||||
"fiskerinc.com/modules/cache"
|
||||
"fiskerinc.com/modules/common/authproviders"
|
||||
c "fiskerinc.com/modules/common/context"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/jwt"
|
||||
"fiskerinc.com/modules/logger"
|
||||
"fiskerinc.com/modules/utils"
|
||||
)
|
||||
|
||||
var ErrorNoAPITokenHeader = errors.New("no api token header")
|
||||
|
||||
type AuthAPIToken struct {
|
||||
APITokens queries.APITokensInterface
|
||||
APICalls queries.APICallsInterface
|
||||
JWTAuth bool
|
||||
GroupKey string
|
||||
authJWT *AuthJWTToken
|
||||
cache *cache.APITokenCache
|
||||
onceAuthJTW sync.Once
|
||||
onceCache sync.Once
|
||||
AuthBase
|
||||
}
|
||||
|
||||
func (a *AuthAPIToken) GetHandler(requiredRoles map[string][]adminroles.RoleID, next http.HandlerFunc) http.HandlerFunc {
|
||||
wrapper := func(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
var ctx context.Context
|
||||
|
||||
ctx, err = a.Check(requiredRoles, r)
|
||||
if errors.Is(err, ErrorNoAPITokenHeader) && a.JWTAuth {
|
||||
ctx, err = a.GetJWTAuth().Check(requiredRoles, r)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Warn().Msgf("AuthAPIToken %s %s '%v'", r.Method, r.RequestURI, err)
|
||||
utils.RespError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if ctx != nil {
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return wrapper
|
||||
}
|
||||
|
||||
func (a *AuthAPIToken) Check(requiredRoles map[string][]adminroles.RoleID, r *http.Request) (context.Context, error) {
|
||||
// if there are no required roles, anyone can access
|
||||
if !a.hasRoles(requiredRoles) {
|
||||
return r.Context(), nil
|
||||
}
|
||||
|
||||
token, ok := a.HasAPIToken(r)
|
||||
if !ok {
|
||||
return nil, ErrorNoAPITokenHeader
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), ClientIDContextKey, token)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// API Token is hard coded as provider
|
||||
r = utils.AUTHWriteProviderToRequest(authproviders.FiskerAPIKey, r)
|
||||
roles, ok := a.getRolesForProvider(authproviders.FiskerAPIKey, requiredRoles)
|
||||
if !ok {
|
||||
return nil, errors.New(adminroles.MissingPermissionError)
|
||||
}
|
||||
|
||||
err := a.checkAPIToken(roles, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx = a.addContext(r.Context(), c.ProviderKey, authproviders.FiskerAPIKey)
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (a *AuthAPIToken) HasAPIToken(r *http.Request) (string, bool) {
|
||||
token, ok := r.Header[cache.ApiKeyHeader]
|
||||
if !ok || len(token[0]) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return token[0], true
|
||||
}
|
||||
|
||||
func (a *AuthAPIToken) checkAPIToken(requiredRoles []adminroles.RoleID, token string) error {
|
||||
if len(requiredRoles) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
roles, err := a.APICache().Get(token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
checker := adminroles.RolesChecker{}
|
||||
checker.SetRequiredRoles(requiredRoles)
|
||||
|
||||
return checker.Check(strings.Split(roles, ","))
|
||||
}
|
||||
|
||||
func (a *AuthAPIToken) GetJWTAuth() *AuthJWTToken {
|
||||
a.onceAuthJTW.Do(func() {
|
||||
if a.authJWT == nil {
|
||||
a.authJWT = &AuthJWTToken{
|
||||
groupkey: a.GroupKey,
|
||||
apiCalls: a.APICalls,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return a.authJWT
|
||||
}
|
||||
|
||||
func (a *AuthAPIToken) SetGroupKey(gp string) {
|
||||
a.authJWT.SetGroupKey(gp)
|
||||
}
|
||||
|
||||
func (a *AuthAPIToken) APICache() *cache.APITokenCache {
|
||||
a.onceCache.Do(func() {
|
||||
if a.cache == nil {
|
||||
a.cache = &cache.APITokenCache{
|
||||
APITokens: a.APITokens,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return a.cache
|
||||
}
|
||||
|
||||
func (a *AuthAPIToken) GetValidator() jwt.JWTValidatorInterface {
|
||||
return a.GetJWTAuth().GetValidator()
|
||||
}
|
||||
|
||||
func (a *AuthAPIToken) Close() {
|
||||
a.APITokens = nil
|
||||
a.authJWT = nil
|
||||
a.cache = nil
|
||||
}
|
||||
243
pkg/httphandlers/auth_apitoken_test.go
Normal file
243
pkg/httphandlers/auth_apitoken_test.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package httphandlers_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"fiskerinc.com/modules/adminroles"
|
||||
"fiskerinc.com/modules/cache"
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/common/authproviders"
|
||||
c "fiskerinc.com/modules/common/context"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/db/queries/mocks"
|
||||
"fiskerinc.com/modules/httphandlers"
|
||||
"fiskerinc.com/modules/redis"
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
)
|
||||
|
||||
type testCaseAuthAPIToken struct {
|
||||
RedisClient redis.Client
|
||||
Query queries.APITokensInterface
|
||||
APICalls queries.APICallsInterface
|
||||
RequiredRoles map[string][]adminroles.RoleID
|
||||
JWTAuth bool
|
||||
ExpectedProvider string
|
||||
testhelper.BasicHttpTest
|
||||
}
|
||||
|
||||
func TestAuthAPIToken(t *testing.T) {
|
||||
adminroles.RoleCreate = adminroles.RoleID("efcc3025-e2d8-4212-8227-805c7be39d2c")
|
||||
adminroles.RoleDelete = adminroles.RoleID("8f78dce7-f5f9-4033-a10c-c9c7408bfcfe")
|
||||
|
||||
someErr := errors.New("some err")
|
||||
validExpiresAt := time.Now().Add(time.Hour)
|
||||
client := &RedisMockAuthAPIToken{
|
||||
GetCacheError: redis.ErrNilObject,
|
||||
}
|
||||
db := &mocks.MockAPITokens{
|
||||
DBMockHelper: mocks.DBMockHelper{
|
||||
Error: errors.New("token not found"),
|
||||
},
|
||||
}
|
||||
dbCalls := &mocks.MockAPICalls{
|
||||
DBMockHelper: mocks.DBMockHelper{Error: nil},
|
||||
}
|
||||
apiToken := "XXXXXXXXXXXX"
|
||||
req := testhelper.MakeTestRequestWithHeaders(http.MethodGet, "/", map[string]string{cache.ApiKeyHeader: apiToken}, nil)
|
||||
roles := map[string][]adminroles.RoleID{
|
||||
authproviders.Default: {adminroles.RoleCreate},
|
||||
}
|
||||
|
||||
tests := []testCaseAuthAPIToken{
|
||||
{
|
||||
BasicHttpTest: testhelper.BasicHttpTest{
|
||||
Name: "Good API Token, no required permission",
|
||||
Request: req,
|
||||
ExpectedStatus: http.StatusOK,
|
||||
ExpectedResponse: `OK`,
|
||||
},
|
||||
RequiredRoles: nil,
|
||||
RedisClient: client,
|
||||
},
|
||||
{
|
||||
BasicHttpTest: testhelper.BasicHttpTest{
|
||||
Name: "Good API Token",
|
||||
Request: req,
|
||||
ExpectedStatus: http.StatusOK,
|
||||
ExpectedResponse: `OK`,
|
||||
},
|
||||
RequiredRoles: roles,
|
||||
RedisClient: client,
|
||||
Query: &mocks.MockAPITokens{
|
||||
GetResult: &common.APIToken{
|
||||
Token: apiToken,
|
||||
Roles: strings.Join([]string{string(adminroles.RoleCreate)}, ","),
|
||||
},
|
||||
},
|
||||
APICalls: dbCalls,
|
||||
},
|
||||
{
|
||||
BasicHttpTest: testhelper.BasicHttpTest{
|
||||
Name: "Good API Token with expiration",
|
||||
Request: req,
|
||||
ExpectedStatus: http.StatusOK,
|
||||
ExpectedResponse: `OK`,
|
||||
},
|
||||
RequiredRoles: roles,
|
||||
RedisClient: client,
|
||||
Query: &mocks.MockAPITokens{
|
||||
GetResult: &common.APIToken{
|
||||
Token: apiToken,
|
||||
Roles: strings.Join([]string{string(adminroles.RoleCreate)}, ","),
|
||||
ExpiresAt: &validExpiresAt,
|
||||
},
|
||||
},
|
||||
APICalls: dbCalls,
|
||||
},
|
||||
{
|
||||
BasicHttpTest: testhelper.BasicHttpTest{
|
||||
Name: "Good API Token, without permission",
|
||||
Request: req,
|
||||
ExpectedStatus: http.StatusUnauthorized,
|
||||
ExpectedResponse: `{"message":"missing permission","error":"Unauthorized"}`,
|
||||
},
|
||||
RequiredRoles: roles,
|
||||
RedisClient: client,
|
||||
Query: &mocks.MockAPITokens{
|
||||
GetResult: &common.APIToken{
|
||||
Token: apiToken,
|
||||
Roles: strings.Join([]string{string(adminroles.RoleDelete)}, ","),
|
||||
},
|
||||
},
|
||||
APICalls: dbCalls,
|
||||
},
|
||||
{
|
||||
BasicHttpTest: testhelper.BasicHttpTest{
|
||||
Name: "Bad API Token",
|
||||
Request: req,
|
||||
ExpectedStatus: http.StatusUnauthorized,
|
||||
ExpectedResponse: `{"message":"token not found","error":"Unauthorized"}`,
|
||||
},
|
||||
RequiredRoles: roles,
|
||||
RedisClient: client,
|
||||
Query: db,
|
||||
APICalls: dbCalls,
|
||||
},
|
||||
{
|
||||
BasicHttpTest: testhelper.BasicHttpTest{
|
||||
Name: "Unknown API Token",
|
||||
Request: testhelper.MakeTestRequestWithHeaders(http.MethodGet, "/", map[string]string{cache.ApiKeyHeader: "abc"}, nil),
|
||||
ExpectedStatus: http.StatusUnauthorized,
|
||||
ExpectedResponse: `{"message":"token not found","error":"Unauthorized"}`,
|
||||
},
|
||||
RequiredRoles: roles,
|
||||
RedisClient: client,
|
||||
Query: db,
|
||||
APICalls: dbCalls,
|
||||
},
|
||||
{
|
||||
BasicHttpTest: testhelper.BasicHttpTest{
|
||||
Name: "No headers",
|
||||
Request: testhelper.MakeTestRequest(http.MethodGet, "/", nil),
|
||||
ExpectedStatus: http.StatusUnauthorized,
|
||||
ExpectedResponse: `{"message":"no api token header","error":"Unauthorized"}`,
|
||||
},
|
||||
RequiredRoles: roles,
|
||||
RedisClient: client,
|
||||
Query: db,
|
||||
APICalls: dbCalls,
|
||||
},
|
||||
{
|
||||
BasicHttpTest: testhelper.BasicHttpTest{
|
||||
Name: "No headers, JWT auth",
|
||||
Request: testhelper.MakeTestRequest(http.MethodGet, "/", nil),
|
||||
ExpectedStatus: http.StatusUnauthorized,
|
||||
ExpectedResponse: `{"message":"no authorization header","error":"Unauthorized"}`,
|
||||
},
|
||||
JWTAuth: true,
|
||||
RequiredRoles: roles,
|
||||
RedisClient: client,
|
||||
Query: db,
|
||||
APICalls: dbCalls,
|
||||
},
|
||||
{
|
||||
BasicHttpTest: testhelper.BasicHttpTest{
|
||||
Name: "No headers, JWT auth, no required roles",
|
||||
Request: testhelper.MakeTestRequest(http.MethodGet, "/", nil),
|
||||
ExpectedStatus: http.StatusOK,
|
||||
ExpectedResponse: `OK`,
|
||||
},
|
||||
JWTAuth: true,
|
||||
RedisClient: client,
|
||||
Query: db,
|
||||
APICalls: dbCalls,
|
||||
},
|
||||
{
|
||||
BasicHttpTest: testhelper.BasicHttpTest{
|
||||
Name: "Failed api calls log",
|
||||
Request: req,
|
||||
ExpectedStatus: http.StatusOK,
|
||||
ExpectedResponse: `OK`,
|
||||
},
|
||||
RequiredRoles: roles,
|
||||
RedisClient: client,
|
||||
Query: &mocks.MockAPITokens{
|
||||
GetResult: &common.APIToken{
|
||||
Token: apiToken,
|
||||
Roles: strings.Join([]string{string(adminroles.RoleCreate)}, ","),
|
||||
ExpiresAt: &validExpiresAt,
|
||||
},
|
||||
},
|
||||
APICalls: &mocks.MockAPICalls{
|
||||
DBMockHelper: mocks.DBMockHelper{
|
||||
Error: someErr,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.Name, func(t *testing.T) {
|
||||
handler := setupAuthAPIToken(t, test)
|
||||
testhelper.RunBasicHttpTest(t, test.BasicHttpTest, handler)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setupAuthAPIToken(t *testing.T, test testCaseAuthAPIToken) http.HandlerFunc {
|
||||
testHandler := func(w http.ResponseWriter, r *http.Request) {
|
||||
if test.ExpectedProvider != "" {
|
||||
if provider, ok := r.Context().Value(c.ProviderKey).(string); !ok || provider != test.ExpectedProvider {
|
||||
t.Errorf(testhelper.TestErrorTemplate, test.Name, test.ExpectedProvider, provider)
|
||||
}
|
||||
}
|
||||
|
||||
w.Write([]byte(expectedOkBody))
|
||||
}
|
||||
|
||||
auth := httphandlers.AuthAPIToken{
|
||||
APITokens: test.Query,
|
||||
APICalls: test.APICalls,
|
||||
JWTAuth: test.JWTAuth,
|
||||
}
|
||||
|
||||
return auth.GetHandler(test.RequiredRoles, testHandler)
|
||||
}
|
||||
|
||||
type RedisMockAuthAPIToken struct {
|
||||
GetCacheError error
|
||||
SetCacheError error
|
||||
redis.Connection
|
||||
}
|
||||
|
||||
func (m *RedisMockAuthAPIToken) GetCache(key string, dest interface{}, expire int) error {
|
||||
return m.GetCacheError
|
||||
}
|
||||
|
||||
func (m *RedisMockAuthAPIToken) SetCache(string, interface{}, int) error {
|
||||
return m.SetCacheError
|
||||
}
|
||||
39
pkg/httphandlers/auth_base.go
Normal file
39
pkg/httphandlers/auth_base.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package httphandlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"fiskerinc.com/modules/adminroles"
|
||||
"fiskerinc.com/modules/common/authproviders"
|
||||
c "fiskerinc.com/modules/common/context"
|
||||
)
|
||||
|
||||
type AuthBase struct {
|
||||
}
|
||||
|
||||
func (a AuthBase) hasRoles(requiredRoles map[string][]adminroles.RoleID) bool {
|
||||
if len(requiredRoles) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, roles := range requiredRoles {
|
||||
if len(roles) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (a AuthBase) getRolesForProvider(provider string, required map[string][]adminroles.RoleID) (roles []adminroles.RoleID, ok bool) {
|
||||
if roles, ok = required[provider]; ok {
|
||||
return
|
||||
}
|
||||
|
||||
roles, ok = required[authproviders.Default]
|
||||
return
|
||||
}
|
||||
|
||||
func (a AuthBase) addContext(ctx context.Context, key c.ContextType, value string) context.Context {
|
||||
return context.WithValue(ctx, key, value)
|
||||
}
|
||||
166
pkg/httphandlers/auth_jwttoken.go
Normal file
166
pkg/httphandlers/auth_jwttoken.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package httphandlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"fiskerinc.com/modules/adminroles"
|
||||
"fiskerinc.com/modules/common/authproviders"
|
||||
c "fiskerinc.com/modules/common/context"
|
||||
"fiskerinc.com/modules/db/queries"
|
||||
"fiskerinc.com/modules/jwt"
|
||||
"fiskerinc.com/modules/logger"
|
||||
"fiskerinc.com/modules/utils"
|
||||
)
|
||||
|
||||
var errNoUsername = errors.New("no username")
|
||||
|
||||
type AuthCheckerInterface interface {
|
||||
GetHandler(requiredRoles map[string][]adminroles.RoleID, next http.HandlerFunc) http.HandlerFunc
|
||||
Check(requiredRoles map[string][]adminroles.RoleID, r *http.Request) (context.Context, error)
|
||||
GetValidator() jwt.JWTValidatorInterface
|
||||
SetGroupKey(string)
|
||||
Close()
|
||||
}
|
||||
|
||||
type AuthJWTToken struct {
|
||||
groupkey string
|
||||
apiCalls queries.APICallsInterface
|
||||
validator jwt.JWTValidatorInterface
|
||||
AuthBase
|
||||
}
|
||||
|
||||
func (a *AuthJWTToken) GetValidator() jwt.JWTValidatorInterface {
|
||||
if a.validator == nil {
|
||||
a.validator = jwt.NewJWTValidator("")
|
||||
}
|
||||
return a.validator
|
||||
}
|
||||
|
||||
func (a *AuthJWTToken) GroupKey() string {
|
||||
if len(a.groupkey) == 0 {
|
||||
return "custom:groups"
|
||||
}
|
||||
|
||||
return a.groupkey
|
||||
}
|
||||
|
||||
func (a *AuthJWTToken) SetGroupKey(gp string) {
|
||||
a.groupkey = gp
|
||||
}
|
||||
|
||||
func (a *AuthJWTToken) GetHandler(requiredRoles map[string][]adminroles.RoleID, next http.HandlerFunc) http.HandlerFunc {
|
||||
wrapper := func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, err := a.Check(requiredRoles, r)
|
||||
if err != nil {
|
||||
logger.Warn().Msgf("AuthJWTToken %s %s '%v'", r.Method, r.RequestURI, err)
|
||||
utils.RespError(w, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if ctx != nil {
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return wrapper
|
||||
}
|
||||
|
||||
func (a *AuthJWTToken) Check(requiredRoles map[string][]adminroles.RoleID, r *http.Request) (context.Context, error) {
|
||||
// if there are no required roles, anyone can access
|
||||
if !a.hasRoles(requiredRoles) {
|
||||
return r.Context(), nil
|
||||
}
|
||||
|
||||
token, err := jwt.GetAuthorizationHeader(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload, err := a.GetValidator().ValidateToken(token.Token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if username, ok := a.getUsername(payload); ok && username != "" {
|
||||
ctx := context.WithValue(r.Context(), ClientIDContextKey, username)
|
||||
*r = *r.WithContext(ctx)
|
||||
} else {
|
||||
return nil, errNoUsername
|
||||
}
|
||||
|
||||
provider := a.getProvider(payload)
|
||||
roles, ok := a.getRolesForProvider(provider, requiredRoles)
|
||||
if !ok {
|
||||
return nil, errors.New(adminroles.MissingPermissionError)
|
||||
}
|
||||
|
||||
checker := adminroles.RolesChecker{}
|
||||
checker.SetRequiredRoles(roles)
|
||||
err = checker.CheckGroups(payload[a.GroupKey()])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx := a.addContext(r.Context(), c.ProviderKey, provider)
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (a *AuthJWTToken) getProvider(payload map[string]interface{}) string {
|
||||
var ok bool
|
||||
var data interface{}
|
||||
var identities []interface{}
|
||||
var identity map[string]interface{}
|
||||
var provider string
|
||||
|
||||
if data, ok = payload["identities"]; !ok {
|
||||
return authproviders.Default
|
||||
}
|
||||
|
||||
if identities, ok = data.([]interface{}); !ok || len(identities) != 1 {
|
||||
return authproviders.Default
|
||||
}
|
||||
|
||||
if identity, ok = identities[0].(map[string]interface{}); !ok {
|
||||
return authproviders.Default
|
||||
}
|
||||
|
||||
if provider, ok = identity["providerName"].(string); !ok {
|
||||
return authproviders.Default
|
||||
}
|
||||
|
||||
return provider
|
||||
|
||||
}
|
||||
|
||||
func (a *AuthJWTToken) Close() {
|
||||
// nothing to dispose here
|
||||
}
|
||||
|
||||
func (a *AuthJWTToken) getUsername(payload map[string]interface{}) (string, bool) {
|
||||
username, ok := payload["cognito:username"].(string)
|
||||
if ok && username != "" {
|
||||
return username, true
|
||||
}
|
||||
|
||||
username, ok = payload["username"].(string)
|
||||
if ok && username != "" {
|
||||
return username, true
|
||||
}
|
||||
|
||||
username, ok = payload["preferred_username"].(string)
|
||||
if ok && username != "" {
|
||||
return username, true
|
||||
}
|
||||
|
||||
username, ok = payload["email"].(string)
|
||||
if ok && username != "" {
|
||||
return username, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
262
pkg/httphandlers/auth_jwttoken_test.go
Normal file
262
pkg/httphandlers/auth_jwttoken_test.go
Normal file
File diff suppressed because one or more lines are too long
21
pkg/httphandlers/base_url_handler.go
Normal file
21
pkg/httphandlers/base_url_handler.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package httphandlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"fiskerinc.com/modules/utils/envtool"
|
||||
)
|
||||
|
||||
// ServiceBaseURL base url of service i.e. "/service"
|
||||
var ServiceBaseURL = envtool.GetEnv("SERVICE_BASE_URL", "")
|
||||
|
||||
// HandleBaseURL appends base url to path
|
||||
func HandleBaseURL(path string, fn http.HandlerFunc) (string, http.HandlerFunc) {
|
||||
return strings.Join([]string{ServiceBaseURL, path}, ""), fn
|
||||
}
|
||||
|
||||
// HttpRouterHandleBaseURL appends base url to path
|
||||
func HttpRouterHandleBaseURL(path string) string {
|
||||
return strings.Join([]string{ServiceBaseURL, path}, "")
|
||||
}
|
||||
25
pkg/httphandlers/base_url_handler_test.go
Normal file
25
pkg/httphandlers/base_url_handler_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package httphandlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
)
|
||||
|
||||
func TestHandleBaseURL(t *testing.T) {
|
||||
expectedPath := "/base/endpoint"
|
||||
testHandler := func(w http.ResponseWriter, r *http.Request) {}
|
||||
ServiceBaseURL = "/base"
|
||||
|
||||
path, handler := HandleBaseURL("/endpoint", testHandler)
|
||||
|
||||
if path != expectedPath {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Path", expectedPath, path)
|
||||
}
|
||||
|
||||
if reflect.ValueOf(handler).Pointer() != reflect.ValueOf(testHandler).Pointer() {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Handler", reflect.ValueOf(testHandler).Pointer(), reflect.ValueOf(handler).Pointer())
|
||||
}
|
||||
}
|
||||
19
pkg/httphandlers/context.go
Normal file
19
pkg/httphandlers/context.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package httphandlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"fiskerinc.com/modules/common/context"
|
||||
)
|
||||
|
||||
const ClientIDContextKey context.ContextType = "client_id"
|
||||
|
||||
func GetClientID(r *http.Request) string {
|
||||
v := r.Context().Value(ClientIDContextKey)
|
||||
id, ok := v.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return id
|
||||
}
|
||||
37
pkg/httphandlers/cors_handler.go
Normal file
37
pkg/httphandlers/cors_handler.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package httphandlers
|
||||
|
||||
import "net/http"
|
||||
|
||||
// CORSHandler middleware to add CORS headers. USED FOR LOCAL DEVELOPMENT.
|
||||
func CORSHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
wrapper := func(w http.ResponseWriter, r *http.Request) {
|
||||
header := w.Header()
|
||||
|
||||
header.Set("Access-Control-Allow-Credentials", "true")
|
||||
header.Set("Access-Control-Allow-Headers", "*")
|
||||
header.Set("Access-Control-Allow-Origin", "*")
|
||||
header.Set("Access-Control-Allow-Methods", "*")
|
||||
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return wrapper
|
||||
}
|
||||
|
||||
func HttpRouterCORSHandler() http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
header := w.Header()
|
||||
header.Set("Access-Control-Allow-Credentials", "true")
|
||||
header.Set("Access-Control-Allow-Headers", "*")
|
||||
header.Set("Access-Control-Allow-Origin", "*")
|
||||
header.Set("Access-Control-Allow-Methods", "*")
|
||||
|
||||
// Adjust status code to 204
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
21
pkg/httphandlers/log_request.go
Normal file
21
pkg/httphandlers/log_request.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package httphandlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"fiskerinc.com/modules/logger"
|
||||
)
|
||||
|
||||
func LogRequest(next http.HandlerFunc) http.HandlerFunc {
|
||||
wrapper := func(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Info().
|
||||
Str("headers", fmt.Sprintf("%v", r.Header)).
|
||||
Str("ip", r.RemoteAddr).
|
||||
Str("user", GetClientID(r)).
|
||||
Msgf("%s %s", r.Method, r.RequestURI)
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return wrapper
|
||||
}
|
||||
42
pkg/httphandlers/log_request_test.go
Normal file
42
pkg/httphandlers/log_request_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package httphandlers_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
"fiskerinc.com/modules/httphandlers"
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
)
|
||||
|
||||
func TestLogRequest(t *testing.T) {
|
||||
tests := []testhelper.BasicHttpTest{
|
||||
{
|
||||
Name: "GET request",
|
||||
Request: testhelper.MakeTestRequest(http.MethodGet, "/test", nil),
|
||||
ExpectedStatus: http.StatusOK,
|
||||
ExpectedResponse: `null`,
|
||||
},
|
||||
{
|
||||
Name: "POST request",
|
||||
Request: testhelper.MakeTestRequest(http.MethodPost, "/test", common.Car{VIN: "TESTVIN"}),
|
||||
ExpectedStatus: http.StatusOK,
|
||||
ExpectedResponse: `{"vin":"TESTVIN"}`,
|
||||
},
|
||||
}
|
||||
|
||||
echo := func(w http.ResponseWriter, r *http.Request) {
|
||||
b, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
w.Write([]byte(err.Error()))
|
||||
} else {
|
||||
w.Write(b)
|
||||
}
|
||||
}
|
||||
handler := httphandlers.LogRequest(echo)
|
||||
|
||||
for _, test := range tests {
|
||||
testhelper.RunBasicHttpTest(t, test, handler)
|
||||
}
|
||||
}
|
||||
24
pkg/httphandlers/method_checker.go
Normal file
24
pkg/httphandlers/method_checker.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package httphandlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"fiskerinc.com/modules/utils"
|
||||
)
|
||||
|
||||
// MethodAll to handle all http method
|
||||
const MethodAll = "*"
|
||||
|
||||
// CheckMethod middleware to enforce method
|
||||
func CheckMethod(method string, next http.HandlerFunc) http.HandlerFunc {
|
||||
wrapper := func(w http.ResponseWriter, r *http.Request) {
|
||||
if method != r.Method && method != MethodAll {
|
||||
utils.RespError(w, http.StatusBadRequest, fmt.Sprintf("Not %s method", method))
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return wrapper
|
||||
}
|
||||
35
pkg/httphandlers/method_checker_test.go
Normal file
35
pkg/httphandlers/method_checker_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package httphandlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
)
|
||||
|
||||
func TestCheckMethod(t *testing.T) {
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
testCheckMethod(t, CheckMethod(http.MethodGet, handler), http.StatusOK)
|
||||
testCheckMethod(t, CheckMethod(MethodAll, handler), http.StatusOK)
|
||||
testCheckMethod(t, CheckMethod(http.MethodPost, handler), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func testCheckMethod(t *testing.T, fn http.HandlerFunc, expectedStatusCode int) {
|
||||
req := setupAuthorizeRequest()
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
fn(recorder, req)
|
||||
|
||||
if recorder.Code != expectedStatusCode {
|
||||
t.Errorf(testhelper.TestErrorTemplate, "Status code", expectedStatusCode, recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func setupAuthorizeRequest() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
return req
|
||||
}
|
||||
32
pkg/httphandlers/panic_http_handler.go
Normal file
32
pkg/httphandlers/panic_http_handler.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package httphandlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
"fiskerinc.com/modules/logger"
|
||||
"fiskerinc.com/modules/utils"
|
||||
)
|
||||
|
||||
// PanicHandler Panic handler wrapper for http handlers
|
||||
func PanicHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
wrapper := func(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Debug().Msgf("%s %s", r.Method, r.RequestURI)
|
||||
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
HttpRouterPanicHandler(w, r, err)
|
||||
}
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return wrapper
|
||||
}
|
||||
|
||||
func HttpRouterPanicHandler(w http.ResponseWriter, r *http.Request, p interface{}) {
|
||||
logger.Error().Msgf("PanicHandler %v %s", p, string(debug.Stack()))
|
||||
utils.RespError(w, http.StatusInternalServerError, fmt.Sprintf("PanicHandler %v", p))
|
||||
}
|
||||
50
pkg/httphandlers/panic_http_handler_test.go
Normal file
50
pkg/httphandlers/panic_http_handler_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package httphandlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
)
|
||||
|
||||
func checkPanic(w http.ResponseWriter, r *http.Request) {
|
||||
panic("Test panic")
|
||||
}
|
||||
|
||||
func noPanic(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
func TestPanicHandler(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
httpHandler http.HandlerFunc
|
||||
status int
|
||||
body string
|
||||
}
|
||||
|
||||
tests := []test{
|
||||
{name: "Panic", httpHandler: checkPanic, status: http.StatusInternalServerError, body: `{"message":"PanicHandler Test panic","error":"Internal Server Error"}`},
|
||||
{name: "No Panic", httpHandler: noPanic, status: http.StatusOK, body: "OK"},
|
||||
}
|
||||
|
||||
for _, i := range tests {
|
||||
request, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
response := httptest.NewRecorder()
|
||||
handler := http.HandlerFunc(PanicHandler(i.httpHandler))
|
||||
handler.ServeHTTP(response, request)
|
||||
|
||||
if response.Result().StatusCode != i.status {
|
||||
t.Errorf(testhelper.TestErrorTemplate, i.name, i.status, response.Result().Status)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(response.Body.String()) != i.body {
|
||||
t.Errorf(testhelper.TestErrorTemplate, i.name, i.body, response.Body)
|
||||
}
|
||||
}
|
||||
}
|
||||
28
pkg/httphandlers/parserequest_handler.go
Normal file
28
pkg/httphandlers/parserequest_handler.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package httphandlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"net/http"
|
||||
|
||||
"fiskerinc.com/modules/validator"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func ParseRequest(r *http.Request, data interface{}) error {
|
||||
err := json.NewDecoder(r.Body).Decode(data)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return errors.Wrapf(validator.ValidateStruct(data), "request %v", data)
|
||||
}
|
||||
|
||||
func ParseXMLRequest(r *http.Request, data interface{}) error {
|
||||
err := xml.NewDecoder(r.Body).Decode(data)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return errors.Wrapf(validator.ValidateStruct(data), "request %v", data)
|
||||
}
|
||||
24
pkg/httphandlers/parserequest_handler_test.go
Normal file
24
pkg/httphandlers/parserequest_handler_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package httphandlers_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"fiskerinc.com/modules/common"
|
||||
h "fiskerinc.com/modules/httphandlers"
|
||||
th "fiskerinc.com/modules/testhelper"
|
||||
)
|
||||
|
||||
func TestParseRequestHandler(t *testing.T) {
|
||||
req := th.MakeTestRequest(http.MethodPost, "http://test.com", common.Car{
|
||||
VIN: "1G1FP87S3GN100062",
|
||||
Model: "Ocean",
|
||||
Year: 2022,
|
||||
})
|
||||
var car common.Car
|
||||
err := h.ParseRequest(req, &car)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
}
|
||||
85
pkg/httphandlers/swagger_docs_handler.go
Normal file
85
pkg/httphandlers/swagger_docs_handler.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package httphandlers
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
httpSwagger "github.com/swaggo/http-swagger"
|
||||
|
||||
"fiskerinc.com/modules/logger"
|
||||
"fiskerinc.com/modules/utils/envtool"
|
||||
)
|
||||
|
||||
var swaggerHandler func(http.ResponseWriter, *http.Request)
|
||||
var username = envtool.GetEnv("SWAGGER_USERNAME", "")
|
||||
var password = envtool.GetEnv("SWAGGER_PASSWORD", "")
|
||||
|
||||
func GetSwaggerHandler() http.HandlerFunc {
|
||||
swaggerHandler = httpSwagger.Handler()
|
||||
|
||||
return swaggerDocs
|
||||
}
|
||||
|
||||
func swaggerDocs(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Info().Msgf("SwaggerDocs %s %s", r.Method, r.RequestURI)
|
||||
|
||||
if username != "" && password != "" && !basicAuth(username, password, w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
if redirectToDocs(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
swaggerHandler(w, r)
|
||||
}
|
||||
|
||||
func basicAuth(username, password string, w http.ResponseWriter, r *http.Request) bool {
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth == "" {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("401 Unauthorized\n"))
|
||||
return false
|
||||
}
|
||||
|
||||
credentials, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic "))
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("400 Bad Request\n"))
|
||||
return false
|
||||
}
|
||||
|
||||
parts := strings.SplitN(string(credentials), ":", 2)
|
||||
if len(parts) != 2 || parts[0] != username || parts[1] != password {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("401 Unauthorized\n"))
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func redirectToDocs(w http.ResponseWriter, r *http.Request) bool {
|
||||
url := shouldRedirect(r)
|
||||
if url == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
http.Redirect(w, r, url, http.StatusMovedPermanently)
|
||||
return true
|
||||
}
|
||||
|
||||
func shouldRedirect(r *http.Request) string {
|
||||
u, err := url.Parse(r.RequestURI)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Send()
|
||||
}
|
||||
if strings.HasSuffix(u.Path, ".html") || strings.HasSuffix(u.Path, ".json") || strings.HasSuffix(u.Path, ".js") || strings.HasSuffix(u.Path, ".css") {
|
||||
return ""
|
||||
}
|
||||
|
||||
return path.Join(u.Path, "index.html")
|
||||
}
|
||||
66
pkg/httphandlers/swagger_docs_handler_test.go
Normal file
66
pkg/httphandlers/swagger_docs_handler_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package httphandlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"fiskerinc.com/modules/testhelper"
|
||||
)
|
||||
|
||||
const expectedRedirectURL = "/docs/index.html"
|
||||
const expectedNoRedirectURL = ""
|
||||
|
||||
func TestSwaggerDocsRedirect(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
request *http.Request
|
||||
expectedRedirect string
|
||||
}
|
||||
|
||||
tests := []test{
|
||||
{
|
||||
name: "Redirect to index.html",
|
||||
request: makeRequest("http://test.com/docs"),
|
||||
expectedRedirect: expectedRedirectURL,
|
||||
},
|
||||
{
|
||||
name: "/ Redirect to index.html",
|
||||
request: makeRequest("http://test.com/docs/"),
|
||||
expectedRedirect: expectedRedirectURL,
|
||||
},
|
||||
{
|
||||
name: "Requests index.html, no redirect",
|
||||
request: makeRequest("http://test.com/docs/index.html"),
|
||||
expectedRedirect: expectedNoRedirectURL,
|
||||
},
|
||||
{
|
||||
name: "Requests .js, no redirect",
|
||||
request: makeRequest("http://test.com/docs/index.js"),
|
||||
expectedRedirect: expectedNoRedirectURL,
|
||||
},
|
||||
{
|
||||
name: "Requests .css, no redirect",
|
||||
request: makeRequest("http://test.com/docs/index.css"),
|
||||
expectedRedirect: expectedNoRedirectURL,
|
||||
},
|
||||
{
|
||||
name: "Requests .json, no redirect",
|
||||
request: makeRequest("http://test.com/docs/index.json"),
|
||||
expectedRedirect: expectedNoRedirectURL,
|
||||
},
|
||||
}
|
||||
|
||||
for _, item := range tests {
|
||||
path := shouldRedirect(item.request)
|
||||
|
||||
if path != item.expectedRedirect {
|
||||
t.Errorf(testhelper.TestErrorTemplate, item.name, item.expectedRedirect, path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func makeRequest(url string) *http.Request {
|
||||
request, _ := http.NewRequest(http.MethodGet, url, nil)
|
||||
request.RequestURI = url
|
||||
return request
|
||||
}
|
||||
Reference in New Issue
Block a user