120 lines
2.7 KiB
Go
120 lines
2.7 KiB
Go
package middleware
|
|
|
|
import (
	"errors"
	"os"
	"strings"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/golang-jwt/jwt/v5"
	"github.com/rs/xid"
	"github.com/themakers/hlog"
)
|
|
|
|
// ContextKey is a distinct string type used as the key for request-local
// values whose key must not be a plain string (see AnswererChain, which stores
// the session id under ContextKey(SessionKey)).
type ContextKey string

// Session-related names.
const (
	// SessionKey is the header name that carries the session id. Wrapped in
	// ContextKey, it is also the Locals key under which AnswererChain stores
	// that id.
	SessionKey = "X-SessionKey"

	// SessionCookie is a session-related name; it is not referenced in this
	// file. NOTE(review): presumably the session cookie name — confirm at use sites.
	SessionCookie = "session"
)

// Keys under which request-scoped values are stored on the request context.
const (
	// AccountId is the user-value key under which JWTAuth stores the token's
	// "id" claim; GetAccountId reads it back.
	AccountId = "id"

	// HlogCtxKey is the key under which ContextLogger stores the logger;
	// GetLogger reads it back.
	HlogCtxKey = "logger"
)
|
|
|
|
func AnswererChain() fiber.Handler {
|
|
return func(c *fiber.Ctx) error {
|
|
session := c.Get(SessionKey)
|
|
|
|
if session == "" {
|
|
session := xid.New().String()
|
|
c.Set(SessionKey, session)
|
|
c.Locals(ContextKey(SessionKey), session)
|
|
} else {
|
|
c.Locals(ContextKey(SessionKey), session)
|
|
}
|
|
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
func JWTAuth() fiber.Handler {
|
|
return func(c *fiber.Ctx) error {
|
|
//todo также сделать для хуков на добавление удаление в амо
|
|
if c.Path() == "/quiz/logo" {
|
|
return c.Next()
|
|
}
|
|
authHeader := c.Get("Authorization")
|
|
if authHeader == "" {
|
|
c.Status(fiber.StatusUnauthorized).SendString("no JWT found")
|
|
return nil
|
|
}
|
|
|
|
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
|
if tokenString == authHeader {
|
|
c.Status(fiber.StatusUnauthorized).SendString("invalid JWT Header: missing Bearer")
|
|
return nil
|
|
}
|
|
|
|
publicKey := os.Getenv("PUBLIC_ACCESS_SECRET_KEY")
|
|
if publicKey == "" {
|
|
// TODO log
|
|
c.Status(fiber.StatusInternalServerError).SendString("public key not found")
|
|
return nil
|
|
}
|
|
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
return jwt.ParseRSAPublicKeyFromPEM([]byte(publicKey))
|
|
})
|
|
if err != nil {
|
|
c.Status(fiber.StatusUnauthorized).SendString("invalid JWT")
|
|
return nil
|
|
}
|
|
|
|
if token.Valid {
|
|
expirationTime, err := token.Claims.GetExpirationTime()
|
|
if err != nil {
|
|
c.Status(fiber.StatusUnauthorized).SendString("no expiration time in JWT")
|
|
return nil
|
|
}
|
|
if time.Now().Unix() >= expirationTime.Unix() {
|
|
c.Status(fiber.StatusUnauthorized).SendString("expired JWT")
|
|
return nil
|
|
}
|
|
} else {
|
|
c.Status(fiber.StatusUnauthorized).SendString("invalid JWT")
|
|
return nil
|
|
}
|
|
|
|
m, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
c.Status(fiber.StatusInternalServerError).SendString("broken token claims")
|
|
return nil
|
|
}
|
|
|
|
id, ok := m["id"].(string)
|
|
if !ok || id == "" {
|
|
c.Status(fiber.StatusUnauthorized).SendString("missing id claim in JWT")
|
|
return nil
|
|
}
|
|
|
|
c.Context().SetUserValue(AccountId, id)
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
func ContextLogger(logger hlog.Logger) fiber.Handler {
|
|
return func(c *fiber.Ctx) error {
|
|
c.Locals(HlogCtxKey, logger)
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
func GetAccountId(c *fiber.Ctx) (string, bool) {
|
|
id, ok := c.Context().UserValue(AccountId).(string)
|
|
return id, ok
|
|
}
|
|
|
|
func GetLogger(c *fiber.Ctx) hlog.Logger {
|
|
logger := c.Context().UserValue(HlogCtxKey).(hlog.Logger)
|
|
return logger
|
|
}
|