heruvym/middleware/http_middleware.go

243 lines
5.6 KiB
Go
Raw Normal View History

2021-04-11 09:48:15 +00:00
package http_middleware
import (
"bitbucket.org/BlackBroker/heruvym/jwt_adapter"
"bitbucket.org/BlackBroker/heruvym/middleware/hijack"
"context"
"fmt"
"github.com/skeris/authService/errors"
"net/http"
"runtime/debug"
"strings"
"go.uber.org/zap"
)
type (
	// MiddlewareFunc is one link of a handler chain built by Chain: it receives
	// the request/response plus the next handler in the chain as next.
	MiddlewareFunc func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc)

	// AfterFunc may derive a new request context once the auth token has been
	// stored in it. DefaultCookieAndRecoveryMiddleware panics with the
	// returned error when it is non-nil.
	AfterFunc func(ctx context.Context, r *http.Request) (context.Context, error)

	// SetCookieValueFunc writes the encoded auth token value onto the response.
	SetCookieValueFunc func(w http.ResponseWriter, value string)
)
// DefaultChain builds the standard handler pipeline. From outermost to
// innermost, the middlewares are: MiddlewareRecovery, MiddlewareLogger,
// DefaultCookieAndRecoveryMiddleware, and then the caller-supplied mws in the
// order given.
func DefaultChain(
	log *zap.Logger,
	recFn RecoverFunc,
	afn AfterFunc,
	setCookieValue SetCookieValueFunc,
	mws ...MiddlewareFunc,
) http.HandlerFunc {
	const builtins = 3
	pipeline := make([]MiddlewareFunc, 0, builtins+len(mws))
	pipeline = append(pipeline,
		MiddlewareRecovery(log, recFn),
		MiddlewareLogger(log),
		DefaultCookieAndRecoveryMiddleware(log, recFn, afn, setCookieValue),
	)
	pipeline = append(pipeline, mws...)
	return Chain(pipeline...)
}
func DefaultCookieAndRecoveryMiddleware(
log *zap.Logger,
recFn RecoverFunc,
afn AfterFunc,
setCookieValue SetCookieValueFunc,
) MiddlewareFunc {
const headerKey = jwt_adapter.DefaultHeaderKey
return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
2021-05-01 10:05:45 +00:00
var (
err error
tokenHeader string
)
2021-04-11 09:48:15 +00:00
cookie := &jwt_adapter.JwtAdapter{}
2021-05-01 10:05:45 +00:00
if r.Method == http.MethodGet {
tokenHeader = fmt.Sprintf("Bearer %s", r.Form.Get(headerKey))
} else {
tokenHeader = r.Header.Get(headerKey)
}
2021-04-11 09:48:15 +00:00
if tokenHeader == "" {
fmt.Println("ERROR NO authHEader")
cookie.Init()
} else {
splitted := strings.Split(tokenHeader, " ")
if len(splitted) != 2 {
w.WriteHeader(http.StatusForbidden)
w.Header().Add("Content-Type", "application/json")
return
}
tokenPart := splitted[1]
cookie, err = jwt_adapter.Decode(tokenPart)
if err != nil {
cookie.Init()
}
}
recovery := struct {
val interface{}
trace string
}{}
ctx := context.WithValue(r.Context(), headerKey, cookie)
if afn != nil {
c, err := afn(ctx, r)
if err != nil {
panic(err)
}
ctx = c
}
w, commit := hijack.New(w, func(w http.ResponseWriter) {
cookie.LastSeen = jwt_adapter.Timestamp()
if val, err := cookie.Encode(); err != nil {
panic(err)
} else {
setCookieValue(w, val)
}
if recovery.val != nil {
log.Error("handler recovered", zap.Any("recovered", recovery.val))
code, message := recFn(recovery.val, ctx)
w.WriteHeader(code)
if _, err := fmt.Fprint(w, message); err != nil {
log.Error("error writing panic response", zap.Error(err))
}
}
})
defer func() {
if rec := recover(); rec != nil {
recovery.val = rec
recovery.trace = string(debug.Stack())
}
commit()
}()
next(w, r.WithContext(ctx))
}
}
// Chain composes mws into a single http.HandlerFunc. mws[0] is the outermost
// middleware, and each middleware receives the rest of the chain as its next
// argument.
//
// The innermost middleware receives a no-op handler instead of nil. A
// middleware that calls next unconditionally (for example MiddlewareLogger)
// can therefore be placed last without a nil-function panic. With no
// middlewares, the result is a no-op handler.
func Chain(mws ...MiddlewareFunc) http.HandlerFunc {
	h := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
	for i := len(mws) - 1; i >= 0; i-- {
		h = link(mws[i], h)
	}
	return h
}
// link adapts mw into an http.HandlerFunc that invokes mw with next as its
// continuation.
func link(mw MiddlewareFunc, next http.HandlerFunc) http.HandlerFunc {
	handler := func(w http.ResponseWriter, r *http.Request) {
		mw(w, r, next)
	}
	return handler
}
type (
	// RecoverFunc maps a value recovered from a panic, together with the
	// request context, to the HTTP status code and body sent to the client.
	RecoverFunc func(rec interface{}, ctx context.Context) (code int, message string)
)
// MiddlewareRecovery returns a middleware that recovers a panic raised further
// down the chain. It writes the status code and body chosen by recFn, then
// logs the panic at error level. A failure while writing the body is logged
// separately.
func MiddlewareRecovery(log *zap.Logger, recFn RecoverFunc) MiddlewareFunc {
	return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}
			status, body := recFn(rec, r.Context())
			w.WriteHeader(status)
			if _, writeErr := fmt.Fprint(w, body); writeErr != nil {
				log.Error("error writing panic response", zap.Error(writeErr))
			}
			log.Error("panic in http handler",
				zap.Int("code", status),
				zap.String("message", body),
				zap.Any("recovered", rec),
			)
		}()
		next(w, r)
	}
}
// MiddlewareLogger returns a middleware that logs each request URL at debug
// level and then passes the request on unchanged.
func MiddlewareLogger(log *zap.Logger) MiddlewareFunc {
	logRequest := func(r *http.Request) {
		log.Debug("http request", zap.String("url", r.URL.String()))
	}
	return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		logRequest(r)
		next(w, r)
	}
}
// Handler adapts a plain http.HandlerFunc into a MiddlewareFunc. h always
// runs first; the next handler in the chain then runs if one was supplied.
func Handler(h http.HandlerFunc) MiddlewareFunc {
	return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		h(w, r)
		if next == nil {
			return
		}
		next(w, r)
	}
}
// Wrap returns the top-level HTTP handler: mux wrapped in DefaultChain.
//
// Panic handling: a panic value that is an error recognised by
// errors.IsForbidden becomes 403, and one recognised by
// errors.IsUnauthenticated becomes 401. Any other panic becomes 500 with the
// value's text as the body.
//
// Response headers:
//   - The encoded auth token is written to the jwt_adapter.DefaultHeaderKey
//     response header as "Bearer <token>".
//   - CORS headers are set on every request.
//   - OPTIONS requests are answered with 200 and never reach mux.
//   - All other requests also get no-cache headers before being served by mux.
func Wrap(mux *http.ServeMux, logger *zap.Logger) http.HandlerFunc {
	recoverFn := func(rec interface{}, ctx context.Context) (int, string) {
		err, isErr := rec.(error)
		if !isErr {
			return http.StatusInternalServerError, fmt.Sprintf("%v", rec)
		}
		if v, ok := errors.IsForbidden(err); ok {
			return http.StatusForbidden, v.Error()
		}
		if v, ok := errors.IsUnauthenticated(err); ok {
			return http.StatusUnauthorized, v.Error()
		}
		return http.StatusInternalServerError, err.Error()
	}

	setAuthHeader := func(w http.ResponseWriter, val string) {
		const key = jwt_adapter.DefaultHeaderKey
		bearer := fmt.Sprintf("Bearer %s", val)
		if w.Header().Get(key) == "" {
			w.Header().Add(key, bearer)
		} else {
			w.Header().Set(key, bearer)
		}
	}

	corsAndNoCache := func(w http.ResponseWriter, r *http.Request) {
		//todo specify origins
		// NOTE(review): browsers reject credentialed CORS responses when
		// Access-Control-Allow-Origin is "*", and "*" in Allow/Expose-Headers
		// is taken literally for such requests. An explicit origin allow-list
		// is needed before credentials will work.
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.Header().Set("Access-Control-Allow-Credentials", "true")
		w.Header().Set("Access-Control-Allow-Headers", "*")
		w.Header().Set("Access-Control-Expose-Headers", "*")
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}
		w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
		w.Header().Set("Pragma", "no-cache")
		w.Header().Set("Expires", "0")
		mux.ServeHTTP(w, r)
	}

	return DefaultChain(logger, recoverFn, nil, setAuthHeader, Handler(corsAndNoCache))
}