add mw to common
This commit is contained in:
parent
d59eb04dc6
commit
44b37825e8
1
go.mod
1
go.mod
@ -22,6 +22,7 @@ require (
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.15 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasthttp v1.51.0 // indirect
|
||||
github.com/valyala/tcplisten v1.0.0 // indirect
|
||||
|
2
go.sum
2
go.sum
@ -60,6 +60,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
|
101
middleware/middleware.go
Normal file
101
middleware/middleware.go
Normal file
@ -0,0 +1,101 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
SessionKey = "X-SessionKey"
|
||||
SessionCookie = "session"
|
||||
AccountId = "id"
|
||||
)
|
||||
|
||||
func AnswererChain() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
session := c.Get(SessionKey)
|
||||
|
||||
if session == "" {
|
||||
session := xid.New().String()
|
||||
c.Set(SessionKey, session)
|
||||
c.Locals(ContextKey(SessionKey), session)
|
||||
} else {
|
||||
c.Locals(ContextKey(SessionKey), session)
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func JWTAuth() fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
authHeader := c.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
c.Status(fiber.StatusUnauthorized).SendString("no JWT found")
|
||||
return nil
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenString == authHeader {
|
||||
c.Status(fiber.StatusUnauthorized).SendString("invalid JWT Header: missing Bearer")
|
||||
return nil
|
||||
}
|
||||
|
||||
publicKey := os.Getenv("PUBLIC_ACCESS_SECRET_KEY")
|
||||
if publicKey == "" {
|
||||
// TODO log
|
||||
c.Status(fiber.StatusInternalServerError).SendString("public key not found")
|
||||
return nil
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
return jwt.ParseRSAPublicKeyFromPEM([]byte(publicKey))
|
||||
})
|
||||
if err != nil {
|
||||
c.Status(fiber.StatusUnauthorized).SendString("invalid JWT")
|
||||
return nil
|
||||
}
|
||||
|
||||
if token.Valid {
|
||||
expirationTime, err := token.Claims.GetExpirationTime()
|
||||
if err != nil {
|
||||
c.Status(fiber.StatusUnauthorized).SendString("no expiration time in JWT")
|
||||
return nil
|
||||
}
|
||||
if time.Now().Unix() >= expirationTime.Unix() {
|
||||
c.Status(fiber.StatusUnauthorized).SendString("expired JWT")
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
c.Status(fiber.StatusUnauthorized).SendString("invalid JWT")
|
||||
return nil
|
||||
}
|
||||
|
||||
m, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
c.Status(fiber.StatusInternalServerError).SendString("broken token claims")
|
||||
return nil
|
||||
}
|
||||
|
||||
id, ok := m["id"].(string)
|
||||
if !ok || id == "" {
|
||||
c.Status(fiber.StatusUnauthorized).SendString("missing id claim in JWT")
|
||||
return nil
|
||||
}
|
||||
|
||||
c.Context().SetUserValue(AccountId, id)
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func GetAccountId(c *fiber.Ctx) (string, bool) {
|
||||
id, ok := c.Context().UserValue(AccountId).(string)
|
||||
return id, ok
|
||||
}
|
Loading…
Reference in New Issue
Block a user