add mw to common

This commit is contained in:
Pavel 2024-03-13 19:41:09 +03:00
parent d59eb04dc6
commit 44b37825e8
3 changed files with 104 additions and 0 deletions

1
go.mod

@ -22,6 +22,7 @@ require (
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.51.0 // indirect
github.com/valyala/tcplisten v1.0.0 // indirect

2
go.sum

@ -60,6 +60,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=

101
middleware/middleware.go Normal file

@ -0,0 +1,101 @@
package middleware
import (
"github.com/gofiber/fiber/v2"
"os"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/rs/xid"
)
type ContextKey string
const (
SessionKey = "X-SessionKey"
SessionCookie = "session"
AccountId = "id"
)
func AnswererChain() fiber.Handler {
return func(c *fiber.Ctx) error {
session := c.Get(SessionKey)
if session == "" {
session := xid.New().String()
c.Set(SessionKey, session)
c.Locals(ContextKey(SessionKey), session)
} else {
c.Locals(ContextKey(SessionKey), session)
}
return c.Next()
}
}
func JWTAuth() fiber.Handler {
return func(c *fiber.Ctx) error {
authHeader := c.Get("Authorization")
if authHeader == "" {
c.Status(fiber.StatusUnauthorized).SendString("no JWT found")
return nil
}
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if tokenString == authHeader {
c.Status(fiber.StatusUnauthorized).SendString("invalid JWT Header: missing Bearer")
return nil
}
publicKey := os.Getenv("PUBLIC_ACCESS_SECRET_KEY")
if publicKey == "" {
// TODO log
c.Status(fiber.StatusInternalServerError).SendString("public key not found")
return nil
}
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
return jwt.ParseRSAPublicKeyFromPEM([]byte(publicKey))
})
if err != nil {
c.Status(fiber.StatusUnauthorized).SendString("invalid JWT")
return nil
}
if token.Valid {
expirationTime, err := token.Claims.GetExpirationTime()
if err != nil {
c.Status(fiber.StatusUnauthorized).SendString("no expiration time in JWT")
return nil
}
if time.Now().Unix() >= expirationTime.Unix() {
c.Status(fiber.StatusUnauthorized).SendString("expired JWT")
return nil
}
} else {
c.Status(fiber.StatusUnauthorized).SendString("invalid JWT")
return nil
}
m, ok := token.Claims.(jwt.MapClaims)
if !ok {
c.Status(fiber.StatusInternalServerError).SendString("broken token claims")
return nil
}
id, ok := m["id"].(string)
if !ok || id == "" {
c.Status(fiber.StatusUnauthorized).SendString("missing id claim in JWT")
return nil
}
c.Context().SetUserValue(AccountId, id)
return c.Next()
}
}
func GetAccountId(c *fiber.Ctx) (string, bool) {
id, ok := c.Context().UserValue(AccountId).(string)
return id, ok
}