diff --git a/go.mod b/go.mod index ad968e2..1570e2e 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/rs/xid v1.5.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect diff --git a/go.sum b/go.sum index 3f99893..870d83a 100644 --- a/go.sum +++ b/go.sum @@ -60,6 +60,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000..cb5f83f --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,101 @@ +package middleware + +import ( + "github.com/gofiber/fiber/v2" + "os" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/rs/xid" +) + +type ContextKey string + +const ( + SessionKey = "X-SessionKey" + SessionCookie = "session" + AccountId = "id" +) + +func AnswererChain() fiber.Handler { + return func(c *fiber.Ctx) error { + session := c.Get(SessionKey) + + if session == "" { + session := xid.New().String() + c.Set(SessionKey, session) + c.Locals(ContextKey(SessionKey), session) + } else { + c.Locals(ContextKey(SessionKey), session) + } + + return c.Next() + } +} + +func JWTAuth() fiber.Handler { + return func(c *fiber.Ctx) error { + authHeader := c.Get("Authorization") + if authHeader == "" { + c.Status(fiber.StatusUnauthorized).SendString("no JWT found") + return nil + } + + tokenString := strings.TrimPrefix(authHeader, "Bearer ") + if tokenString == authHeader { + c.Status(fiber.StatusUnauthorized).SendString("invalid JWT Header: missing Bearer") + return nil + } + + publicKey := os.Getenv("PUBLIC_ACCESS_SECRET_KEY") + if publicKey == "" { + // TODO log + c.Status(fiber.StatusInternalServerError).SendString("public key not found") + return nil + } + + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + return jwt.ParseRSAPublicKeyFromPEM([]byte(publicKey)) + }) + if err != nil { + c.Status(fiber.StatusUnauthorized).SendString("invalid JWT") + return nil + } + + if token.Valid { + expirationTime, err := token.Claims.GetExpirationTime() + if err != nil { + c.Status(fiber.StatusUnauthorized).SendString("no expiration time in JWT") + return nil + } + if time.Now().Unix() >= expirationTime.Unix() { + c.Status(fiber.StatusUnauthorized).SendString("expired JWT") + return nil + } + } else { + c.Status(fiber.StatusUnauthorized).SendString("invalid JWT") + return nil + } + + m, ok := token.Claims.(jwt.MapClaims) + if !ok { + c.Status(fiber.StatusInternalServerError).SendString("broken token claims") + return nil + } + + id, ok := m["id"].(string) + if !ok || id == "" { + c.Status(fiber.StatusUnauthorized).SendString("missing id claim in JWT") + return nil + } + + c.Context().SetUserValue(AccountId, id) + return c.Next() + } +} + +func GetAccountId(c *fiber.Ctx) (string, bool) { + id, ok := c.Context().UserValue(AccountId).(string) + return id, ok +}