diff --git a/_example/main.go b/_example/main.go index 116dee7..75e9f0b 100644 --- a/_example/main.go +++ b/_example/main.go @@ -66,6 +66,18 @@ import ( "github.com/go-chi/jwtauth/v5" ) +type dynamicTokenAuth struct { + keySet []byte +} + +func (d *dynamicTokenAuth) JWTAuth() (*jwtauth.JWTAuth, error) { + keySet, err := jwtauth.NewKeySet(d.keySet) + if err != nil { + return nil, err + } + return keySet, nil +} + var tokenAuth *jwtauth.JWTAuth func init() { @@ -74,7 +86,8 @@ func init() { // For debugging/example purposes, we generate and print // a sample jwt token with claims `user_id:123` here: _, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123}) - fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString) + fmt.Printf("DEBUG: a sample jwt for /admin is %s\n\n", tokenString) + fmt.Printf("DEBUG: a sample jwt for /rotate is %s\n\n", sampleJWTRotate) } func main() { @@ -103,6 +116,23 @@ func router() http.Handler { }) }) + r.Group(func(r chi.Router) { + dynamicTokenAuth := dynamicTokenAuth{keySet: keySet} + // Seek, verify and validate JWT tokens based on keys returned by the callback function + r.Use(jwtauth.VerifierDynamic(dynamicTokenAuth.JWTAuth)) + + // Handle valid / invalid tokens. In this example, we use + // the provided authenticator middleware, but you can write your + // own very easily, look at the Authenticator method in jwtauth.go + // and tweak it, its not scary. + r.Use(jwtauth.Authenticator) + + r.Get("/rotate", func(w http.ResponseWriter, r *http.Request) { + _, claims, _ := jwtauth.FromContext(r.Context()) + w.Write([]byte(fmt.Sprintf("protected area. hi %v", claims["user_id"]))) + }) + }) + // Public routes r.Group(func(r chi.Router) { r.Get("/", func(w http.ResponseWriter, r *http.Request) { @@ -112,3 +142,20 @@ func router() http.Handler { return r } + +var ( + keySet = []byte(`{ + "keys": [ + { + "kty": "RSA", + "alg": "RS256", + "kid": "kid", + "use": "sig", + "n": "rgzO_v14UXJ33MvccKI8aIw3YpknVJbRB-m1z1X4j3gaTmmzmb7_naEd1TOKhF6Z1BGupvAKhCs8uHtp5e1PCrp52kzrjv7nqQfDpdppPZmKpwf-OD_lVgLLuCljB71mX9w7T5vI_WiVknuNhm48y0TJQNslpDZum4E2e0BLKUDRKKlo25foGoDuQN535_Xso861U8KsA80jX37BJplQ6IHewV_bbe04NYTVqaFcmLaZCAzh2f8L1h4xt76Y0xF_u8FXt2-rgcWlz17CtZzxC8ZXNI_92pX8CY5LY2eQf_B_n5Rhd5TQvEIdoI1GNBrcKUI9pMeEC4pErcOGgKGH7w", + "e": "AQAB" + } + ] +}`) + + sampleJWTRotate = `eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImtpZCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.APC4bUOmfbcXjBnZnmyiGBpXqlboTB4Qbh_sqJrgSU5AEQlwzjvDJ79eBlty8h6kfq3i5ffy87s-g82ZoRsHqMjwCIvTOVnoEyDgVu68s9lE32uaA0cc2-hbA13DIBsyIUGjehh9c3h93BrUoUr7n0CHgoKgx2OEw1Bq8vm4EqvmFGF-mr_0qi32uudPy3I15SyP1NJfU0ogQEFUdDHww3c8omDmrTPiGlWZAl9AiBMroDu0nq3UOtC4d5Se-361NEGiZ9J_kHcVWGdoMwsi5KEB0Uf3wAfXK3wcXeRu1pTXYKOV3X3g_2ss6mh65bNMsSx-MZUnQv5v6qZMOxMBUA` +) diff --git a/go.mod b/go.mod index 7e0dad8..9e8be69 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,5 @@ go 1.15 require ( github.com/go-chi/chi/v5 v5.0.4 - github.com/lestrrat-go/jwx v1.2.6 + github.com/lestrrat-go/jwx/v2 v2.0.1 ) diff --git a/go.sum b/go.sum index d93a856..3ab2ebc 100644 --- a/go.sum +++ b/go.sum @@ -1,66 +1,40 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.0-20210816181553-5444fa50b93d h1:1iy2qD6JEhHKKhUOA9IWs7mjco7lnw2qx8FsRI2wirE= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.0-20210816181553-5444fa50b93d/go.mod h1:tmAIfUFEirG/Y8jhZ9M+h36obRZAk/1fcSpXwAVlfqE= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 h1:YLtO71vCjJRCBcrPMtQ9nqBsqpA1m5sE92cU+pd5Mcc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeCxkaw7y45JueMRL4DIyJDKs= github.com/go-chi/chi/v5 v5.0.4 h1:5e494iHzsYBiyXQAHHuI4tyJS9M3V84OuX3ufIIGHFo= github.com/go-chi/chi/v5 v5.0.4/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= -github.com/goccy/go-json v0.7.6 h1:H0wq4jppBQ+9222sk5+hPLL25abZQiRuQ6YPnjO9c+A= -github.com/goccy/go-json v0.7.6/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/lestrrat-go/backoff/v2 v2.0.8 h1:oNb5E5isby2kiro9AgdHLv5N5tint1AnDVVf2E2un5A= -github.com/lestrrat-go/backoff/v2 v2.0.8/go.mod h1:rHP/q/r9aT27n24JQLa7JhSQZCKBBOiM/uP402WwN8Y= -github.com/lestrrat-go/blackmagic v1.0.0 h1:XzdxDbuQTz0RZZEmdU7cnQxUtFUzgCSPq8RCz4BxIi4= -github.com/lestrrat-go/blackmagic v1.0.0/go.mod h1:TNgH//0vYSs8VXDCfkZLgIrVTTXQELZffUV0tz3MtdQ= -github.com/lestrrat-go/codegen v1.0.1/go.mod h1:JhJw6OQAuPEfVKUCLItpaVLumDGWQznd1VaXrBk9TdM= -github.com/lestrrat-go/httpcc v1.0.0 h1:FszVC6cKfDvBKcJv646+lkh4GydQg2Z29scgUfkOpYc= -github.com/lestrrat-go/httpcc v1.0.0/go.mod h1:tGS/u00Vh5N6FHNkExqGGNId8e0Big+++0Gf8MBnAvE= -github.com/lestrrat-go/iter v1.0.1 h1:q8faalr2dY6o8bV45uwrxq12bRa1ezKrB6oM9FUgN4A= -github.com/lestrrat-go/iter v1.0.1/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbqPDrJ/OJc= -github.com/lestrrat-go/jwx v1.2.6 h1:XAgfuHaOB7fDZ/6WhVgl8K89af768dU+3Nx4DlTbLIk= -github.com/lestrrat-go/jwx v1.2.6/go.mod h1:tJuGuAI3LC71IicTx82Mz1n3w9woAs2bYJZpkjJQ5aU= +github.com/goccy/go-json v0.9.7 h1:IcB+Aqpx/iMHu5Yooh7jEzJk1JZ7Pjtmys2ukPr7EeM= +github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/lestrrat-go/blackmagic v1.0.1 h1:lS5Zts+5HIC/8og6cGHb0uCcNCa3OUt1ygh3Qz2Fe80= +github.com/lestrrat-go/blackmagic v1.0.1/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= +github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= +github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= +github.com/lestrrat-go/httprc v1.0.1 h1:Cnc4NxIySph38pQPzKbjg5OkKsGR/Cf5xcWt5OlSUDI= +github.com/lestrrat-go/httprc v1.0.1/go.mod h1:5Ml+nB++j6IC0e6LzefJnrpMQDKgDwDCaIQQzhbqhJM= +github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= +github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= +github.com/lestrrat-go/jwx/v2 v2.0.1 h1:BFhFnElL3HVa/e1sXTogmKbMlY2HgfEP1fozVc6/eYA= +github.com/lestrrat-go/jwx/v2 v2.0.1/go.mod h1:xV8+xRcrKbmnScV8adOzUuuTrL8aAZJoY4q2JAqIYU8= github.com/lestrrat-go/option v1.0.0 h1:WqAWL8kh8VcSoD6xjSH34/1m8yxluXQbDeKNfvFeEO4= github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201217014255-9d1352758620 h1:3wPMTskHO3+O6jqTEXyFcsnuxMQOqYSaHsDxcbUXpqA= -golang.org/x/crypto v0.0.0-20201217014255-9d1352758620/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f h1:OeJjE6G4dgCY4PIXvIRQbE8+RX+uXZyGhUy/ksMGJoc= +golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20200918232735-d647fc253266/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= -golang.org/x/tools v0.0.0-20210114065538-d78b04bdf963/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/jwtauth.go b/jwtauth.go index e6534e8..883582f 100644 --- a/jwtauth.go +++ b/jwtauth.go @@ -2,13 +2,16 @@ package jwtauth import ( "context" + "encoding/json" "errors" + "fmt" "net/http" "strings" "time" - "github.com/lestrrat-go/jwx/jwa" - "github.com/lestrrat-go/jwx/jwt" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" ) type JWTAuth struct { @@ -16,6 +19,7 @@ type JWTAuth struct { signKey interface{} // private-key verifyKey interface{} // public-key, only used by RSA and ECDSA algorithms verifier jwt.ParseOption + keySet jwk.Set } var ( @@ -35,15 +39,29 @@ var ( func New(alg string, signKey interface{}, verifyKey interface{}) *JWTAuth { ja := &JWTAuth{alg: jwa.SignatureAlgorithm(alg), signKey: signKey, verifyKey: verifyKey} + algorithm := jwa.KeyAlgorithmFrom(ja.alg) if ja.verifyKey != nil { - ja.verifier = jwt.WithVerify(ja.alg, ja.verifyKey) + ja.verifier = jwt.WithKey(algorithm, ja.verifyKey) } else { - ja.verifier = jwt.WithVerify(ja.alg, ja.signKey) + ja.verifier = jwt.WithKey(algorithm, ja.signKey) } return ja } +func NewKeySet(keySet []byte) (*JWTAuth, error) { + ks := jwk.NewSet() + err := json.Unmarshal(keySet, &ks) + if err != nil { + return nil, err + } + + ja := &JWTAuth{keySet: ks} + ja.verifier = jwt.WithKeySet(ks) + + return ja, nil +} + // Verifier http middleware handler will verify a JWT string from a http request. // // Verifier will search for a JWT token in a http request, in the order: @@ -64,6 +82,26 @@ func Verifier(ja *JWTAuth) func(http.Handler) http.Handler { return Verify(ja, TokenFromHeader, TokenFromCookie) } +// VerifierDynamic http middleware handler will verify a JWT string from a http request. +// +// Verifier will search for a JWT token in a http request, in the order: +// 1. 'jwt' URI query parameter +// 2. 'Authorization: BEARER T' request header +// 3. Cookie 'jwt' value +// +// The first JWT string that is found as a query parameter, authorization header +// or cookie header is then decoded by the `jwt-go` library and a *jwt.Token +// object is set on the request context. In the case of a signature decoding error +// the Verifier will also set the error on the request context. +// +// The Verifier always calls the next http handler in sequence, which can either +// be the generic `jwtauth.Authenticator` middleware or your own custom handler +// which checks the request context jwt token and error to prepare a custom +// http response. +func VerifierDynamic(jaf func() (*JWTAuth, error)) func(http.Handler) http.Handler { + return VerifyDynamic(jaf, TokenFromHeader, TokenFromCookie) +} + func Verify(ja *JWTAuth, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { hfn := func(w http.ResponseWriter, r *http.Request) { @@ -76,6 +114,23 @@ func Verify(ja *JWTAuth, findTokenFns ...func(r *http.Request) string) func(http } } +func VerifyDynamic(jaf func() (*JWTAuth, error), findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + hfn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ja, err := jaf() + if err != nil { + ctx = NewContext(ctx, nil, err) + next.ServeHTTP(w, r.WithContext(ctx)) + } + token, err := VerifyRequest(ja, r, findTokenFns...) + ctx = NewContext(ctx, token, err) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(hfn) + } +} + func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Request) string) (jwt.Token, error) { var tokenString string @@ -115,6 +170,10 @@ func VerifyToken(ja *JWTAuth, tokenString string) (jwt.Token, error) { } func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenString string, err error) { + if ja.keySet != nil { + return nil, "", fmt.Errorf("encode not supported") + } + t = jwt.New() for k, v := range claims { t.Set(k, v) @@ -132,7 +191,7 @@ func (ja *JWTAuth) Decode(tokenString string) (jwt.Token, error) { } func (ja *JWTAuth) sign(token jwt.Token) ([]byte, error) { - return jwt.Sign(token, ja.alg, ja.signKey) + return jwt.Sign(token, jwt.WithKey(ja.alg, ja.signKey)) } func (ja *JWTAuth) parse(payload []byte) (jwt.Token, error) { @@ -143,11 +202,11 @@ func (ja *JWTAuth) parse(payload []byte) (jwt.Token, error) { // jwt library func ErrorReason(err error) error { switch err.Error() { - case "exp not satisfied", ErrExpired.Error(): + case jwt.ErrTokenExpired().Error(), ErrExpired.Error(): return ErrExpired - case "iat not satisfied", ErrIATInvalid.Error(): + case jwt.ErrInvalidIssuedAt().Error(), ErrIATInvalid.Error(): return ErrIATInvalid - case "nbf not satisfied", ErrNBFInvalid.Error(): + case jwt.ErrTokenNotYetValid().Error(), ErrNBFInvalid.Error(): return ErrNBFInvalid default: return ErrUnauthorized diff --git a/jwtauth_test.go b/jwtauth_test.go index dd2bc57..ea174c4 100644 --- a/jwtauth_test.go +++ b/jwtauth_test.go @@ -5,6 +5,8 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jws" "io" "io/ioutil" "log" @@ -16,7 +18,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/jwtauth/v5" - "github.com/lestrrat-go/jwx/jwt" + "github.com/lestrrat-go/jwx/v2/jwt" ) var ( @@ -41,6 +43,27 @@ MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALxo3PCjFw4QjgOX06QCJIJBnXXNiEYw DLxxa5/7QyH6y77nCRQyJ3x3UwF9rUD0RCsp4sNdX5kOQ9PUyHyOtCUCAwEAAQ== -----END PUBLIC KEY----- ` + + KeySet = `{ + "keys": [ + { + "kty": "RSA", + "n": "vGjc8KMXDhCOA5fTpAIkgkGddc2IRjAMvHFrn_tDIfrLvucJFDInfHdTAX2tQPREKyniw11fmQ5D09TIfI60JQ", + "e": "AQAB", + "alg": "RS256", + "kid": "1", + "use": "sig" + }, + { + "kty": "RSA", + "n": "foo", + "e": "AQAB", + "alg": "RS256", + "kid": "2", + "use": "sig" + } + ] +}` ) func init() { @@ -51,6 +74,18 @@ func init() { // Tests // +func TestNewKeySet(t *testing.T) { + _, err := jwtauth.NewKeySet([]byte("not a valid key set")) + if err == nil { + t.Fatal("The error should not be nil") + } + + _, err = jwtauth.NewKeySet([]byte(KeySet)) + if err != nil { + t.Fatalf(err.Error()) + } +} + func TestSimpleRSA(t *testing.T) { privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String)) @@ -98,6 +133,45 @@ func TestSimpleRSA(t *testing.T) { } } +func TestKeySetRSA(t *testing.T) { + privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String)) + + privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes) + + if err != nil { + t.Fatalf(err.Error()) + } + + KeySetAuth, _ := jwtauth.NewKeySet([]byte(KeySet)) + claims := map[string]interface{}{ + "key": "val", + "key2": "val2", + "key3": "val3", + } + + signed := newJwtToken(jwa.RS256, privateKey, "1", claims) + + token, err := KeySetAuth.Decode(signed) + + if err != nil { + t.Fatalf("Failed to decode token string %s\n", err.Error()) + } + + tokenClaims, err := token.AsMap(context.Background()) + if err != nil { + t.Fatal(err.Error()) + } + + if !reflect.DeepEqual(claims, tokenClaims) { + t.Fatalf("The decoded claims don't match the original ones\n") + } + + _, _, err = KeySetAuth.Encode(claims) + if err.Error() != "encode not supported" { + t.Fatalf("Expect error to equal %s. Found: %s.", "encode not supported", err.Error()) + } +} + func TestSimple(t *testing.T) { r := chi.NewRouter() @@ -116,7 +190,7 @@ func TestSimple(t *testing.T) { } h := http.Header{} - h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{})) + h.Set("Authorization", "BEARER "+newJwtToken(jwa.HS256, []byte("wrong"), "", map[string]interface{}{})) if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" { t.Fatalf(resp) } @@ -125,12 +199,12 @@ func TestSimple(t *testing.T) { t.Fatalf(resp) } // wrong token secret and wrong alg - h.Set("Authorization", "BEARER "+newJwt512Token([]byte("wrong"), map[string]interface{}{})) + h.Set("Authorization", "BEARER "+newJwtToken(jwa.HS512, []byte("wrong"), "", map[string]interface{}{})) if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" { t.Fatalf(resp) } // correct token secret but wrong alg - h.Set("Authorization", "BEARER "+newJwt512Token(TokenSecret, map[string]interface{}{})) + h.Set("Authorization", "BEARER "+newJwtToken(jwa.HS512, TokenSecret, "", map[string]interface{}{})) if status, resp := testRequest(t, ts, "GET", "/", h, nil); status != 401 || resp != "token is unauthorized\n" { t.Fatalf(resp) } @@ -196,7 +270,7 @@ func TestMore(t *testing.T) { } h := http.Header{} - h.Set("Authorization", "BEARER "+newJwtToken([]byte("wrong"), map[string]interface{}{})) + h.Set("Authorization", "BEARER "+newJwtToken(jwa.HS256, []byte("wrong"), "", map[string]interface{}{})) if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" { t.Fatalf(resp) } @@ -205,12 +279,12 @@ func TestMore(t *testing.T) { t.Fatalf(resp) } // wrong token secret and wrong alg - h.Set("Authorization", "BEARER "+newJwt512Token([]byte("wrong"), map[string]interface{}{})) + h.Set("Authorization", "BEARER "+newJwtToken(jwa.HS512, []byte("wrong"), "", map[string]interface{}{})) if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" { t.Fatalf(resp) } // correct token secret but wrong alg - h.Set("Authorization", "BEARER "+newJwt512Token(TokenSecret, map[string]interface{}{})) + h.Set("Authorization", "BEARER "+newJwtToken(jwa.HS512, TokenSecret, "", map[string]interface{}{})) if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" { t.Fatalf(resp) } @@ -231,6 +305,99 @@ func TestMore(t *testing.T) { } } +func TestDynamic(t *testing.T) { + anotherKeySet := `{ + "keys": [ + { + "kty": "RSA", + "n": "vGjc8KMXDhCOA5fTpAIkgkGddc2IRjAMvHFrn_tDIfrLvucJFDInfHdTAX2tQPREKyniw11fmQ5D09TIfI60JQ", + "e": "AQAB", + "alg": "RS256", + "kid": "anotherKID", + "use": "sig" + } + ] +}` + + privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String)) + privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes) + if err != nil { + t.Fatalf(err.Error()) + } + + r := chi.NewRouter() + + keySet := []byte(KeySet) + keySetPtr := &keySet + + dynamicJWTAuthFunc := func() (*jwtauth.JWTAuth, error) { + keySet, err := jwtauth.NewKeySet(*keySetPtr) + if err != nil { + return nil, err + } + return keySet, nil + } + + // Protected routes + r.Group(func(r chi.Router) { + r.Use(jwtauth.VerifierDynamic(dynamicJWTAuthFunc)) + + authenticator := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token, _, err := jwtauth.FromContext(r.Context()) + + if err != nil { + http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized) + return + } + + if err := jwt.Validate(token); err != nil { + http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized) + return + } + + // Token is authenticated, pass it through + next.ServeHTTP(w, r) + }) + } + r.Use(authenticator) + + r.Get("/admin", func(w http.ResponseWriter, r *http.Request) { + _, claims, err := jwtauth.FromContext(r.Context()) + + if err != nil { + w.Write([]byte(fmt.Sprintf("error! %v", err))) + return + } + + w.Write([]byte(fmt.Sprintf("protected, user:%v", claims["user_id"]))) + }) + }) + + // Public routes + r.Group(func(r chi.Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("welcome")) + }) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + h := http.Header{} + h.Set("Authorization", "BEARER "+newJwtToken(jwa.RS256, privateKey, "1", map[string]interface{}{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)})) + if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" { + t.Fatalf(resp) + } + + // dynamically modifying JWTAuth return so kid 1 is no longer supported + *keySetPtr = []byte(anotherKeySet) + h.Set("Authorization", "BEARER "+newJwtToken(jwa.RS256, privateKey, "1", map[string]interface{}{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)})) + if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 401 || resp != "token is unauthorized\n" { + t.Fatalf(resp) + } +} + // // Test helper functions // @@ -262,29 +429,23 @@ func testRequest(t *testing.T, ts *httptest.Server, method, path string, header return resp.StatusCode, string(respBody) } -func newJwtToken(secret []byte, claims ...map[string]interface{}) string { +func newJwtToken(alg jwa.SignatureAlgorithm, secret interface{}, kid string, claims ...map[string]interface{}) string { token := jwt.New() if len(claims) > 0 { for k, v := range claims[0] { token.Set(k, v) } } - tokenPayload, err := jwt.Sign(token, "HS256", secret) - if err != nil { - log.Fatal(err) - } - return string(tokenPayload) -} -func newJwt512Token(secret []byte, claims ...map[string]interface{}) string { - // use-case: when token is signed with a different alg than expected - token := jwt.New() - if len(claims) > 0 { - for k, v := range claims[0] { - token.Set(k, v) + headers := jws.NewHeaders() + if kid != "" { + err := headers.Set("kid", kid) + if err != nil { + log.Fatal(err) } } - tokenPayload, err := jwt.Sign(token, "HS512", secret) + + tokenPayload, err := jwt.Sign(token, jwt.WithKey(alg, secret, jws.WithProtectedHeaders(headers))) if err != nil { log.Fatal(err) } @@ -293,6 +454,6 @@ func newJwt512Token(secret []byte, claims ...map[string]interface{}) string { func newAuthHeader(claims ...map[string]interface{}) http.Header { h := http.Header{} - h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...)) + h.Set("Authorization", "BEARER "+newJwtToken(jwa.HS256, TokenSecret, "", claims...)) return h }