From b00c21698439ffcab69e899580f493bc01c58682 Mon Sep 17 00:00:00 2001 From: Sergey Beryozkin Date: Mon, 15 Apr 2024 19:18:06 +0100 Subject: [PATCH] Check configured signing algorithm when reading the key (#785) --- .../jwt/build/impl/JwtClaimsBuilderImpl.java | 2 +- .../jwt/build/impl/JwtEncryptionImpl.java | 44 +++++++++++-------- .../jwt/build/impl/JwtSignatureImpl.java | 33 ++++++++------ .../jwt/build/JwtBuildConfigSource.java | 2 + .../io/smallrye/jwt/build/JwtEncryptTest.java | 26 +++++++++++ .../io/smallrye/jwt/build/JwtSignTest.java | 24 ++++++++++ 6 files changed, 98 insertions(+), 33 deletions(-) diff --git a/implementation/jwt-build/src/main/java/io/smallrye/jwt/build/impl/JwtClaimsBuilderImpl.java b/implementation/jwt-build/src/main/java/io/smallrye/jwt/build/impl/JwtClaimsBuilderImpl.java index 151f7903..86e73534 100644 --- a/implementation/jwt-build/src/main/java/io/smallrye/jwt/build/impl/JwtClaimsBuilderImpl.java +++ b/implementation/jwt-build/src/main/java/io/smallrye/jwt/build/impl/JwtClaimsBuilderImpl.java @@ -202,7 +202,7 @@ public JwtSignatureBuilder jws() { */ @Override public JwtSignatureBuilder header(String name, Object value) { - if ("alg".equals(name)) { + if (HeaderParameterNames.ALGORITHM.equals(name)) { return algorithm(toSignatureAlgorithm((String) value)); } else { headers.put(name, value); diff --git a/implementation/jwt-build/src/main/java/io/smallrye/jwt/build/impl/JwtEncryptionImpl.java b/implementation/jwt-build/src/main/java/io/smallrye/jwt/build/impl/JwtEncryptionImpl.java index b630c80a..c621414a 100644 --- a/implementation/jwt-build/src/main/java/io/smallrye/jwt/build/impl/JwtEncryptionImpl.java +++ b/implementation/jwt-build/src/main/java/io/smallrye/jwt/build/impl/JwtEncryptionImpl.java @@ -13,6 +13,7 @@ import org.jose4j.jwa.AlgorithmConstraints; import org.jose4j.jwe.JsonWebEncryption; import org.jose4j.jwk.JsonWebKey; +import org.jose4j.jwx.HeaderParameterNames; import io.smallrye.jwt.algorithm.ContentEncryptionAlgorithm; import io.smallrye.jwt.algorithm.KeyEncryptionAlgorithm; @@ -124,9 +125,9 @@ public String encryptWithSecret(String secret) throws JwtEncryptionException { */ @Override public JwtEncryptionBuilder header(String name, Object value) { - if ("alg".equals(name)) { + if (HeaderParameterNames.ALGORITHM.equals(name)) { return keyAlgorithm(toKeyEncryptionAlgorithm((String) value)); - } else if ("enc".equals(name)) { + } else if (HeaderParameterNames.ENCRYPTION_METHOD.equals(name)) { return contentAlgorithm(toContentEncryptionAlgorithm((String) value)); } else { headers.put(name, value); @@ -139,7 +140,7 @@ public JwtEncryptionBuilder header(String name, Object value) { */ @Override public JwtEncryptionBuilder keyAlgorithm(KeyEncryptionAlgorithm algorithm) { - headers.put("alg", algorithm.getAlgorithm()); + headers.put(HeaderParameterNames.ALGORITHM, algorithm.getAlgorithm()); return this; } @@ -148,7 +149,7 @@ public JwtEncryptionBuilder keyAlgorithm(KeyEncryptionAlgorithm algorithm) { */ @Override public JwtEncryptionBuilder contentAlgorithm(ContentEncryptionAlgorithm algorithm) { - headers.put("enc", algorithm.getAlgorithm()); + headers.put(HeaderParameterNames.ENCRYPTION_METHOD, algorithm.getAlgorithm()); return this; } @@ -157,7 +158,7 @@ public JwtEncryptionBuilder contentAlgorithm(ContentEncryptionAlgorithm algorith */ @Override public JwtEncryptionBuilder keyId(String keyId) { - headers.put("kid", keyId); + headers.put(HeaderParameterNames.KEY_ID, keyId); return this; } @@ -171,8 +172,8 @@ private String encryptInternal(Key key) { for (Map.Entry entry : headers.entrySet()) { jwe.getHeaders().setObjectHeaderValue(entry.getKey(), entry.getValue()); } - if (innerSigned && !headers.containsKey("cty")) { - jwe.getHeaders().setObjectHeaderValue("cty", "JWT"); + if (innerSigned && !headers.containsKey(HeaderParameterNames.CONTENT_TYPE)) { + jwe.getHeaders().setObjectHeaderValue(HeaderParameterNames.CONTENT_TYPE, "JWT"); } String keyAlgorithm = getKeyEncryptionAlgorithm(key); jwe.setAlgorithmConstraints(new AlgorithmConstraints(AlgorithmConstraints.ConstraintType.PERMIT, keyAlgorithm)); @@ -193,18 +194,24 @@ private boolean isRelaxKeyValidation() { return JwtBuildUtils.getConfigProperty(JwtBuildUtils.ENC_KEY_RELAX_VALIDATION_PROPERTY, Boolean.class, false); } - private String getKeyEncryptionAlgorithm(Key keyEncryptionKey) { - String alg = (String) headers.get("alg"); + private String getConfiguredKeyEncryptionAlgorithm() { + String alg = (String) headers.get(HeaderParameterNames.ALGORITHM); if (alg == null) { try { alg = JwtBuildUtils.getConfigProperty(JwtBuildUtils.NEW_TOKEN_KEY_ENCRYPTION_ALG_PROPERTY, String.class); if (alg != null) { alg = KeyEncryptionAlgorithm.fromAlgorithm(alg).getAlgorithm(); + headers.put(HeaderParameterNames.ALGORITHM, alg); } } catch (Exception ex) { throw ImplMessages.msg.unsupportedKeyEncryptionAlgorithm(alg); } } + return alg; + } + + private String getKeyEncryptionAlgorithm(Key keyEncryptionKey) { + String alg = getConfiguredKeyEncryptionAlgorithm(); if (keyEncryptionKey instanceof RSAPublicKey) { if (alg == null) { @@ -233,7 +240,7 @@ private static boolean isXecPublicKey(Key encKey) { } private String getContentEncryptionAlgorithm() { - String alg = (String) headers.get("enc"); + String alg = (String) headers.get(HeaderParameterNames.ENCRYPTION_METHOD); if (alg == null) { try { alg = JwtBuildUtils.getConfigProperty(JwtBuildUtils.NEW_TOKEN_CONTENT_ENCRYPTION_ALG_PROPERTY, String.class); @@ -256,18 +263,17 @@ private static String getKeyContentFromLocation(String keyLocation) { } Key getEncryptionKeyFromKeyContent(String keyContent) { - String kid = (String) headers.get("kid"); - String algHeader = (String) headers.get("alg"); + String kid = (String) headers.get(HeaderParameterNames.KEY_ID); + String alg = getConfiguredKeyEncryptionAlgorithm(); // Try PEM format first - default to RSA_OAEP_256 if no algorithm header is set Key key = KeyUtils.tryAsPemEncryptionPublicKey(keyContent, - (algHeader == null ? KeyEncryptionAlgorithm.RSA_OAEP_256 - : KeyEncryptionAlgorithm.fromAlgorithm(algHeader))); + (alg == null ? KeyEncryptionAlgorithm.RSA_OAEP_256 : KeyEncryptionAlgorithm.fromAlgorithm(alg))); if (key == null) { if (kid == null) { kid = JwtBuildUtils.getConfigProperty(JwtBuildUtils.ENC_KEY_ID_PROPERTY, String.class); if (kid != null) { - headers.put("kid", kid); + headers.put(HeaderParameterNames.KEY_ID, kid); } } // Try to load JWK from a single JWK resource or JWK set resource @@ -275,15 +281,15 @@ Key getEncryptionKeyFromKeyContent(String keyContent) { if (jwk != null) { // if the user has already set the algorithm header then JWK `alg` header, if set, must match it key = KeyUtils.getPublicOrSecretEncryptingKey(jwk, - (algHeader == null ? null : KeyEncryptionAlgorithm.fromAlgorithm(algHeader))); + (alg == null ? null : KeyEncryptionAlgorithm.fromAlgorithm(alg))); if (key != null) { // if the algorithm header is not set then use JWK `alg` - if (algHeader == null && jwk.getAlgorithm() != null) { - headers.put("alg", jwk.getAlgorithm()); + if (alg == null && jwk.getAlgorithm() != null) { + headers.put(HeaderParameterNames.ALGORITHM, jwk.getAlgorithm()); } // if 'kid' header is not set then use JWK `kid` if (kid == null && jwk.getKeyId() != null) { - headers.put("kid", jwk.getKeyId()); + headers.put(HeaderParameterNames.KEY_ID, jwk.getKeyId()); } } } diff --git a/implementation/jwt-build/src/main/java/io/smallrye/jwt/build/impl/JwtSignatureImpl.java b/implementation/jwt-build/src/main/java/io/smallrye/jwt/build/impl/JwtSignatureImpl.java index 5e8fd35e..1e2d13f2 100644 --- a/implementation/jwt-build/src/main/java/io/smallrye/jwt/build/impl/JwtSignatureImpl.java +++ b/implementation/jwt-build/src/main/java/io/smallrye/jwt/build/impl/JwtSignatureImpl.java @@ -14,6 +14,7 @@ import org.jose4j.jwk.JsonWebKey; import org.jose4j.jws.JsonWebSignature; import org.jose4j.jwt.JwtClaims; +import org.jose4j.jwx.HeaderParameterNames; import io.smallrye.jwt.algorithm.SignatureAlgorithm; import io.smallrye.jwt.build.JwtEncryptionBuilder; @@ -169,8 +170,8 @@ private String signInternal(Key signingKey) { for (Map.Entry entry : headers.entrySet()) { jws.setHeader(entry.getKey(), entry.getValue()); } - if (!headers.containsKey("typ")) { - jws.setHeader("typ", "JWT"); + if (!headers.containsKey(HeaderParameterNames.TYPE)) { + jws.setHeader(HeaderParameterNames.TYPE, "JWT"); } String algorithm = getSignatureAlgorithm(signingKey); @@ -193,18 +194,24 @@ private boolean isRelaxKeyValidation() { return JwtBuildUtils.getConfigProperty(JwtBuildUtils.SIGN_KEY_RELAX_VALIDATION_PROPERTY, Boolean.class, false); } - private String getSignatureAlgorithm(Key signingKey) { - String alg = (String) headers.get("alg"); + private String getConfiguredSignatureAlgorithm() { + String alg = (String) headers.get(HeaderParameterNames.ALGORITHM); if (alg == null) { try { alg = JwtBuildUtils.getConfigProperty(JwtBuildUtils.NEW_TOKEN_SIGNATURE_ALG_PROPERTY, String.class); if (alg != null) { alg = SignatureAlgorithm.valueOf(alg.toUpperCase()).getAlgorithm(); + headers.put(HeaderParameterNames.ALGORITHM, alg); } } catch (Exception ex) { throw ImplMessages.msg.unsupportedSignatureAlgorithm(alg); } } + return alg; + } + + private String getSignatureAlgorithm(Key signingKey) { + String alg = getConfiguredSignatureAlgorithm(); if ("none".equals(alg)) { throw ImplMessages.msg.noneSignatureAlgorithmUnsupported(); } @@ -255,17 +262,17 @@ static String getKeyContentFromLocation(String keyLocation) { } Key getSigningKeyFromKeyContent(String keyContent) { - String kid = (String) headers.get("kid"); - String algHeader = (String) headers.get("alg"); + String kid = (String) headers.get(HeaderParameterNames.KEY_ID); + String alg = getConfiguredSignatureAlgorithm(); - // Try PEM format first - default to RS256 if no algorithm header is set + // Try PEM format first - default to RS256 if the algorithm is unknown Key key = KeyUtils.tryAsPemSigningPrivateKey(keyContent, - (algHeader == null ? SignatureAlgorithm.RS256 : SignatureAlgorithm.fromAlgorithm(algHeader))); + (alg == null ? SignatureAlgorithm.RS256 : SignatureAlgorithm.fromAlgorithm(alg))); if (key == null) { if (kid == null) { kid = JwtBuildUtils.getConfigProperty(JwtBuildUtils.SIGN_KEY_ID_PROPERTY, String.class); if (kid != null) { - headers.put("kid", kid); + headers.put(HeaderParameterNames.KEY_ID, kid); } } @@ -274,15 +281,15 @@ Key getSigningKeyFromKeyContent(String keyContent) { if (jwk != null) { // if the user has already set the algorithm header then JWK `alg` header, if set, must match it key = KeyUtils.getPrivateOrSecretSigningKey(jwk, - (algHeader == null ? null : SignatureAlgorithm.fromAlgorithm(algHeader))); + (alg == null ? null : SignatureAlgorithm.fromAlgorithm(alg))); if (key != null) { // if the algorithm header is not set then use JWK `alg` - if (algHeader == null && jwk.getAlgorithm() != null) { - headers.put("alg", jwk.getAlgorithm()); + if (alg == null && jwk.getAlgorithm() != null) { + headers.put(HeaderParameterNames.ALGORITHM, jwk.getAlgorithm()); } // if 'kid' header is not set then use JWK `kid` if (kid == null && jwk.getKeyId() != null) { - headers.put("kid", jwk.getKeyId()); + headers.put(HeaderParameterNames.KEY_ID, jwk.getKeyId()); } } } diff --git a/implementation/jwt-build/src/test/java/io/smallrye/jwt/build/JwtBuildConfigSource.java b/implementation/jwt-build/src/test/java/io/smallrye/jwt/build/JwtBuildConfigSource.java index 99b4d182..5ffc87bd 100644 --- a/implementation/jwt-build/src/test/java/io/smallrye/jwt/build/JwtBuildConfigSource.java +++ b/implementation/jwt-build/src/test/java/io/smallrye/jwt/build/JwtBuildConfigSource.java @@ -22,6 +22,8 @@ public class JwtBuildConfigSource implements ConfigSource { JwtBuildUtils.NEW_TOKEN_AUDIENCE_PROPERTY, JwtBuildUtils.NEW_TOKEN_LIFESPAN_PROPERTY, JwtBuildUtils.NEW_TOKEN_OVERRIDE_CLAIMS_PROPERTY, + JwtBuildUtils.NEW_TOKEN_SIGNATURE_ALG_PROPERTY, + JwtBuildUtils.NEW_TOKEN_KEY_ENCRYPTION_ALG_PROPERTY, JwtBuildUtils.SIGN_KEYSTORE_KEY_ALIAS, JwtBuildUtils.ENC_KEYSTORE_KEY_ALIAS)); diff --git a/implementation/jwt-build/src/test/java/io/smallrye/jwt/build/JwtEncryptTest.java b/implementation/jwt-build/src/test/java/io/smallrye/jwt/build/JwtEncryptTest.java index b05c340c..ffddeafb 100644 --- a/implementation/jwt-build/src/test/java/io/smallrye/jwt/build/JwtEncryptTest.java +++ b/implementation/jwt-build/src/test/java/io/smallrye/jwt/build/JwtEncryptTest.java @@ -358,6 +358,32 @@ void encryptWithConfiguredEcKeyAndA128CBCHS256() throws Exception { checkJwtClaims(claims); } + @Test + void encryptWithConfiguredEcKeyAndAlgorithmAndA128CBCHS256() throws Exception { + JwtBuildConfigSource configSource = JwtSignTest.getConfigSource(); + configSource.setEncryptionKeyLocation("/ecPublicKey.pem"); + configSource.setKeyEncryptionAlgorithm("ECDH-ES+A256KW"); + String jweCompact = null; + try { + jweCompact = Jwt.claims() + .claim("customClaim", "custom-value") + .jwe() + .keyId("key-enc-key-id") + .contentAlgorithm(ContentEncryptionAlgorithm.A128CBC_HS256) + .encrypt(); + } finally { + configSource.setEncryptionKeyLocation("/publicKey.pem"); + configSource.setKeyEncryptionAlgorithm(null); + } + + checkJweHeaders(jweCompact, "ECDH-ES+A256KW", "A128CBC-HS256", 4); + + JsonWebEncryption jwe = getJsonWebEncryption(jweCompact, getEcPrivateKey()); + + JwtClaims claims = JwtClaims.parse(jwe.getPlaintextString()); + checkJwtClaims(claims); + } + @Test void encryptWithConfiguredEcKeyAndContentAlgorithm() throws Exception { JwtBuildConfigSource configSource = JwtSignTest.getConfigSource(); diff --git a/implementation/jwt-build/src/test/java/io/smallrye/jwt/build/JwtSignTest.java b/implementation/jwt-build/src/test/java/io/smallrye/jwt/build/JwtSignTest.java index 8fb8c0ff..c2b92ffd 100644 --- a/implementation/jwt-build/src/test/java/io/smallrye/jwt/build/JwtSignTest.java +++ b/implementation/jwt-build/src/test/java/io/smallrye/jwt/build/JwtSignTest.java @@ -892,6 +892,30 @@ void signClaimsEcKey() throws Exception { assertEquals("custom-value", claims.getClaimValue("customClaim")); } + @Test + void signClaimsEcKeyFileWithConfiguredAlgorithm() throws Exception { + JwtBuildConfigSource configSource = getConfigSource(); + configSource.setSigningKeyLocation("/ecPrivateKey.pem"); + configSource.setSignatureAlgorithm(SignatureAlgorithm.ES256.getAlgorithm()); + String jwt = null; + try { + jwt = Jwt.claim("customClaim", "custom-value") + .sign(); + } finally { + configSource.setSigningKeyLocation("/privateKey.pem"); + configSource.setSignatureAlgorithm(null); + } + + PublicKey ecKey = getEcPublicKey(); + JsonWebSignature jws = getVerifiedJws(jwt, ecKey); + JwtClaims claims = JwtClaims.parse(jws.getPayload()); + + assertEquals(4, claims.getClaimsMap().size()); + Map headers = getJwsHeaders(jwt, 2); + checkDefaultClaimsAndHeaders(headers, claims, "ES256", 300); + assertEquals("custom-value", claims.getClaimValue("customClaim")); + } + private static SecretKey createSecretKey() throws Exception { String jwkJson = "{\"kty\":\"oct\",\"k\":\"Fdh9u8rINxfivbrianbbVT1u232VQBZYKx1HGAGPt2I\"}"; JsonWebKey jwk = JsonWebKey.Factory.newJwk(jwkJson);