From 8a8e45c62dd50b137e4f5674e98a97b02580c41f Mon Sep 17 00:00:00 2001 From: Lorain <87760338+justlorain@users.noreply.github.com> Date: Sat, 28 Jan 2023 00:37:18 +0800 Subject: [PATCH] feat: support parse token from post form (#15) --- auth_jwt.go | 25 +++++++++++++++++++++---- auth_jwt_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/auth_jwt.go b/auth_jwt.go index 494e5bd..0dbe94f 100644 --- a/auth_jwt.go +++ b/auth_jwt.go @@ -116,6 +116,8 @@ type HertzJWTMiddleware struct { // - "header:" // - "query:" // - "cookie:" + // - "param:" + // - "form:" TokenLookup string // TokenHeadName is a string in the header. Default value is "Bearer" @@ -230,6 +232,9 @@ var ( // ErrEmptyParamToken can be thrown if authing with parameter in path, the parameter in path is empty ErrEmptyParamToken = errors.New("parameter token is empty") + // ErrEmptyFormToken can be thrown if authing with post form, the form token is empty + ErrEmptyFormToken = errors.New("form token is empty") + // ErrInvalidSigningAlgorithm indicates signing algorithm is invalid, needs to be HS256, HS384, HS512, RS256, RS384 or RS512 ErrInvalidSigningAlgorithm = errors.New("invalid signing algorithm") @@ -664,7 +669,7 @@ func (mw *HertzJWTMiddleware) TokenGenerator(data interface{}) (string, time.Tim return tokenString, expire, nil } -func (mw *HertzJWTMiddleware) jwtFromHeader(ctx context.Context, c *app.RequestContext, key string) (string, error) { +func (mw *HertzJWTMiddleware) jwtFromHeader(_ context.Context, c *app.RequestContext, key string) (string, error) { authHeader := c.Request.Header.Get(key) if authHeader == "" { @@ -680,7 +685,7 @@ func (mw *HertzJWTMiddleware) jwtFromHeader(ctx context.Context, c *app.RequestC return parts[len(parts)-1], nil } -func (mw *HertzJWTMiddleware) jwtFromQuery(ctx context.Context, c *app.RequestContext, key string) (string, error) { +func (mw *HertzJWTMiddleware) jwtFromQuery(_ context.Context, c *app.RequestContext, key string) (string, error) { token := c.Query(key) if token == "" { @@ -690,7 +695,7 @@ func (mw *HertzJWTMiddleware) jwtFromQuery(ctx context.Context, c *app.RequestCo return token, nil } -func (mw *HertzJWTMiddleware) jwtFromCookie(ctx context.Context, c *app.RequestContext, key string) (string, error) { +func (mw *HertzJWTMiddleware) jwtFromCookie(_ context.Context, c *app.RequestContext, key string) (string, error) { cookie := string(c.Cookie(key)) if cookie == "" { @@ -700,7 +705,7 @@ func (mw *HertzJWTMiddleware) jwtFromCookie(ctx context.Context, c *app.RequestC return cookie, nil } -func (mw *HertzJWTMiddleware) jwtFromParam(ctx context.Context, c *app.RequestContext, key string) (string, error) { +func (mw *HertzJWTMiddleware) jwtFromParam(_ context.Context, c *app.RequestContext, key string) (string, error) { token := c.Param(key) if token == "" { @@ -710,6 +715,16 @@ func (mw *HertzJWTMiddleware) jwtFromParam(ctx context.Context, c *app.RequestCo return token, nil } +func (mw *HertzJWTMiddleware) jwtFromForm(_ context.Context, c *app.RequestContext, key string) (string, error) { + token := c.PostForm(key) + + if token == "" { + return "", ErrEmptyFormToken + } + + return token, nil +} + // ParseToken parse jwt token from hertz context func (mw *HertzJWTMiddleware) ParseToken(ctx context.Context, c *app.RequestContext) (*jwt.Token, error) { var token string @@ -732,6 +747,8 @@ func (mw *HertzJWTMiddleware) ParseToken(ctx context.Context, c *app.RequestCont token, err = mw.jwtFromCookie(ctx, c, v) case "param": token, err = mw.jwtFromParam(ctx, c, v) + case "form": + token, err = mw.jwtFromForm(ctx, c, v) } } diff --git a/auth_jwt_test.go b/auth_jwt_test.go index 51c1476..5e1f662 100644 --- a/auth_jwt_test.go +++ b/auth_jwt_test.go @@ -247,6 +247,7 @@ func hertzHandler(auth *HertzJWTMiddleware) *route.Engine { group.Use(auth.MiddlewareFunc()) { group.GET("/hello", helloHandler) + group.POST("/hello", helloHandler) } return r @@ -365,6 +366,29 @@ func TestParseToken(t *testing.T) { assert.DeepEqual(t, http.StatusOK, w.Code) } +func TestParseTokenWithFrom(t *testing.T) { + // the middleware to test + authMiddleware, _ := New(&HertzJWTMiddleware{ + Realm: "test zone", + Key: key, + Timeout: time.Hour, + MaxRefresh: time.Hour * 24, + Authenticator: defaultAuthenticator, + TokenLookup: "form:Authorization", + }) + + handler := hertzHandler(authMiddleware) + + w := ut.PerformRequest(handler, http.MethodPost, "/auth/hello", &ut.Body{ + Body: bytes.NewBufferString("Authorization=" + makeTokenString("HS256", "admin")), + Len: -1, + }, ut.Header{ + Key: "Content-Type", + Value: "application/x-www-form-urlencoded", + }) + assert.DeepEqual(t, http.StatusOK, w.Code) +} + func TestParseTokenRS256(t *testing.T) { // the middleware to test authMiddleware, _ := New(&HertzJWTMiddleware{