// Package http_middleware provides a small middleware chain for net/http
// together with panic-recovery, request-logging and JWT-header middlewares.
package http_middleware

import (
	"context"
	"fmt"
	"net/http"
	"runtime/debug"
	"strings"

	"bitbucket.org/BlackBroker/heruvym/jwt_adapter"
	"bitbucket.org/BlackBroker/heruvym/middleware/hijack"
	"github.com/skeris/authService/errors"
	"go.uber.org/zap"
)

// MiddlewareFunc is one link of the chain. It receives the request and the
// next handler of the chain and decides whether and when to invoke it.
type MiddlewareFunc func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc)

// AfterFunc is called by DefaultCookieAndRecoveryMiddleware once the token has
// been placed into the context. It may return a derived context; a non-nil
// error makes the middleware panic with that error.
type AfterFunc func(ctx context.Context, r *http.Request) (context.Context, error)

// SetCookieValueFunc writes the encoded token value to the response.
type SetCookieValueFunc func(w http.ResponseWriter, value string)

// RecoverFunc converts a recovered panic value into an HTTP status code and a
// response message.
type RecoverFunc func(rec interface{}, ctx context.Context) (code int, message string)

// DefaultChain builds a handler that runs MiddlewareRecovery, MiddlewareLogger
// and DefaultCookieAndRecoveryMiddleware, in that order, followed by mws.
func DefaultChain(
	log *zap.Logger,
	recFn RecoverFunc,
	afn AfterFunc,
	setCookieValue SetCookieValueFunc,
	mws ...MiddlewareFunc,
) http.HandlerFunc {
	defaults := []MiddlewareFunc{
		MiddlewareRecovery(log, recFn),
		MiddlewareLogger(log),
		DefaultCookieAndRecoveryMiddleware(log, recFn, afn, setCookieValue),
	}
	return Chain(append(defaults, mws...)...)
}

// DefaultCookieAndRecoveryMiddleware returns a middleware that:
//
//   - reads the token from the jwt_adapter.DefaultHeaderKey request header
//     (expected form: "<scheme> <token>"), falling back to a freshly
//     initialised adapter when the header is absent or cannot be decoded;
//   - answers 403 and stops the chain when the header is present but does not
//     consist of exactly two space-separated parts;
//   - stores the adapter in the request context under
//     jwt_adapter.DefaultHeaderKey and, if afn is non-nil, lets afn derive the
//     final context (an afn error is re-raised as a panic);
//   - wraps the ResponseWriter with hijack.New so that, on commit, LastSeen is
//     refreshed, the re-encoded token is handed to setCookieValue, and a
//     recovered panic from next is turned into a response through recFn.
func DefaultCookieAndRecoveryMiddleware(
	log *zap.Logger,
	recFn RecoverFunc,
	afn AfterFunc,
	setCookieValue SetCookieValueFunc,
) MiddlewareFunc {
	const headerKey = jwt_adapter.DefaultHeaderKey

	return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		cookie := &jwt_adapter.JwtAdapter{}

		tokenHeader := r.Header.Get(headerKey)
		if tokenHeader == "" {
			log.Debug("no auth header in request")
			cookie.Init()
		} else {
			parts := strings.Split(tokenHeader, " ")
			if len(parts) != 2 {
				// Headers must be set before WriteHeader, otherwise they
				// are not sent.
				w.Header().Add("Content-Type", "application/json")
				w.WriteHeader(http.StatusForbidden)
				return
			}

			decoded, err := jwt_adapter.Decode(parts[1])
			if err != nil {
				// The value returned alongside an error is not trusted
				// (it may well be nil), so keep the fresh adapter and
				// initialise it exactly like the no-header case.
				log.Debug("cannot decode auth token", zap.Error(err))
				cookie.Init()
			} else {
				cookie = decoded
			}
		}

		// Filled in by the deferred recover below and consumed by the
		// commit callback.
		recovery := struct {
			val   interface{}
			trace string
		}{}

		// NOTE(review): the context key is the header-name constant itself
		// (go vet/staticcheck warn about basic-type keys); it is kept because
		// readers of the context elsewhere presumably use the same key.
		ctx := context.WithValue(r.Context(), headerKey, cookie)
		if afn != nil {
			derived, err := afn(ctx, r)
			if err != nil {
				panic(err)
			}
			ctx = derived
		}

		w, commit := hijack.New(w, func(w http.ResponseWriter) {
			cookie.LastSeen = jwt_adapter.Timestamp()

			val, err := cookie.Encode()
			if err != nil {
				panic(err)
			}
			setCookieValue(w, val)

			if recovery.val == nil {
				return
			}

			log.Error(
				"handler recovered",
				zap.Any("recovered", recovery.val),
				zap.String("trace", recovery.trace),
			)

			code, message := recFn(recovery.val, ctx)
			w.WriteHeader(code)
			if _, err := fmt.Fprint(w, message); err != nil {
				log.Error("error writing panic response", zap.Error(err))
			}
		})

		defer func() {
			if rec := recover(); rec != nil {
				recovery.val = rec
				recovery.trace = string(debug.Stack())
			}
			// Commit runs on both the normal and the panic path.
			commit()
		}()

		next(w, r.WithContext(ctx))
	}
}

// Chain composes mws into a single handler; mws[0] runs first and each
// middleware receives the rest of the chain as next. The last middleware
// receives a no-op handler, so calling next unconditionally is always safe.
// An empty chain yields a handler that does nothing.
func Chain(mws ...MiddlewareFunc) http.HandlerFunc {
	h := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
	for i := len(mws) - 1; i >= 0; i-- {
		h = link(mws[i], h)
	}
	return h
}

// link binds mw to its successor, producing a plain http.HandlerFunc.
func link(mw MiddlewareFunc, next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		mw(w, r, next)
	}
}

// MiddlewareRecovery returns a middleware that recovers a panic raised by the
// rest of the chain, asks recFn for a status code and message, writes them to
// the response and logs the incident.
func MiddlewareRecovery(log *zap.Logger, recFn RecoverFunc) MiddlewareFunc {
	return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}

			code, message := recFn(rec, r.Context())
			w.WriteHeader(code)
			if _, err := fmt.Fprint(w, message); err != nil {
				log.Error("error writing panic response", zap.Error(err))
			}

			log.Error(
				"panic in http handler",
				zap.Int("code", code),
				zap.String("message", message),
				zap.Any("recovered", rec),
			)
		}()

		next(w, r)
	}
}

// MiddlewareLogger returns a middleware that logs the request URL at debug
// level and then continues the chain.
func MiddlewareLogger(log *zap.Logger) MiddlewareFunc {
	return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		log.Debug("http request", zap.String("url", r.URL.String()))
		next(w, r)
	}
}

// Handler adapts a plain http.HandlerFunc to a MiddlewareFunc: h is served
// first, then next is invoked if it is non-nil.
func Handler(h http.HandlerFunc) MiddlewareFunc {
	return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
		h.ServeHTTP(w, r)
		if next != nil {
			next(w, r)
		}
	}
}

// Wrap serves mux behind DefaultChain. Recovered panics are mapped to 403 or
// 401 when errors.IsForbidden / errors.IsUnauthenticated recognise them and to
// 500 otherwise. The encoded token is returned to the client as
// "Bearer <token>" in the jwt_adapter.DefaultHeaderKey response header. CORS
// and no-cache headers are set on every response, and OPTIONS requests are
// answered with 200 without reaching mux.
func Wrap(mux *http.ServeMux, logger *zap.Logger) http.HandlerFunc {
	recoverFn := func(rec interface{}, ctx context.Context) (int, string) {
		err, isErr := rec.(error)
		if !isErr {
			return http.StatusInternalServerError, fmt.Sprintf("%v", rec)
		}
		if v, ok := errors.IsForbidden(err); ok {
			return http.StatusForbidden, v.Error()
		}
		if v, ok := errors.IsUnauthenticated(err); ok {
			return http.StatusUnauthorized, v.Error()
		}
		return http.StatusInternalServerError, err.Error()
	}

	setToken := func(w http.ResponseWriter, val string) {
		// Set replaces whatever value is already present, so no separate
		// Add branch is needed.
		w.Header().Set(jwt_adapter.DefaultHeaderKey, "Bearer "+val)
	}

	serve := func(w http.ResponseWriter, r *http.Request) {
		// TODO: specify origins.
		// NOTE(review): browsers reject a wildcard origin when credentials
		// are allowed; this combination only works for non-credentialed
		// requests.
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.Header().Set("Access-Control-Allow-Credentials", "true")
		w.Header().Set("Access-Control-Allow-Headers", "*")
		w.Header().Set("Access-Control-Expose-Headers", "*")

		// CORS preflight: the headers above are all that is needed.
		if r.Method == "OPTIONS" {
			w.WriteHeader(http.StatusOK)
			return
		}

		w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
		w.Header().Set("Pragma", "no-cache")
		w.Header().Set("Expires", "0")

		mux.ServeHTTP(w, r)
	}

	return DefaultChain(logger, recoverFn, nil, setToken, Handler(serve))
}