Skip to content

Commit

Permalink
Check configured signing algorithm when reading the key (#785)
Browse files Browse the repository at this point in the history
  • Loading branch information
sberyozkin authored Apr 15, 2024
1 parent 039c889 commit b00c216
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ public JwtSignatureBuilder jws() {
*/
@Override
public JwtSignatureBuilder header(String name, Object value) {
if ("alg".equals(name)) {
if (HeaderParameterNames.ALGORITHM.equals(name)) {
return algorithm(toSignatureAlgorithm((String) value));
} else {
headers.put(name, value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.jose4j.jwa.AlgorithmConstraints;
import org.jose4j.jwe.JsonWebEncryption;
import org.jose4j.jwk.JsonWebKey;
import org.jose4j.jwx.HeaderParameterNames;

import io.smallrye.jwt.algorithm.ContentEncryptionAlgorithm;
import io.smallrye.jwt.algorithm.KeyEncryptionAlgorithm;
Expand Down Expand Up @@ -124,9 +125,9 @@ public String encryptWithSecret(String secret) throws JwtEncryptionException {
*/
@Override
public JwtEncryptionBuilder header(String name, Object value) {
if ("alg".equals(name)) {
if (HeaderParameterNames.ALGORITHM.equals(name)) {
return keyAlgorithm(toKeyEncryptionAlgorithm((String) value));
} else if ("enc".equals(name)) {
} else if (HeaderParameterNames.ENCRYPTION_METHOD.equals(name)) {
return contentAlgorithm(toContentEncryptionAlgorithm((String) value));
} else {
headers.put(name, value);
Expand All @@ -139,7 +140,7 @@ public JwtEncryptionBuilder header(String name, Object value) {
*/
@Override
public JwtEncryptionBuilder keyAlgorithm(KeyEncryptionAlgorithm algorithm) {
headers.put("alg", algorithm.getAlgorithm());
headers.put(HeaderParameterNames.ALGORITHM, algorithm.getAlgorithm());
return this;
}

Expand All @@ -148,7 +149,7 @@ public JwtEncryptionBuilder keyAlgorithm(KeyEncryptionAlgorithm algorithm) {
*/
@Override
public JwtEncryptionBuilder contentAlgorithm(ContentEncryptionAlgorithm algorithm) {
headers.put("enc", algorithm.getAlgorithm());
headers.put(HeaderParameterNames.ENCRYPTION_METHOD, algorithm.getAlgorithm());
return this;
}

Expand All @@ -157,7 +158,7 @@ public JwtEncryptionBuilder contentAlgorithm(ContentEncryptionAlgorithm algorith
*/
@Override
public JwtEncryptionBuilder keyId(String keyId) {
headers.put("kid", keyId);
headers.put(HeaderParameterNames.KEY_ID, keyId);
return this;
}

Expand All @@ -171,8 +172,8 @@ private String encryptInternal(Key key) {
for (Map.Entry<String, Object> entry : headers.entrySet()) {
jwe.getHeaders().setObjectHeaderValue(entry.getKey(), entry.getValue());
}
if (innerSigned && !headers.containsKey("cty")) {
jwe.getHeaders().setObjectHeaderValue("cty", "JWT");
if (innerSigned && !headers.containsKey(HeaderParameterNames.CONTENT_TYPE)) {
jwe.getHeaders().setObjectHeaderValue(HeaderParameterNames.CONTENT_TYPE, "JWT");
}
String keyAlgorithm = getKeyEncryptionAlgorithm(key);
jwe.setAlgorithmConstraints(new AlgorithmConstraints(AlgorithmConstraints.ConstraintType.PERMIT, keyAlgorithm));
Expand All @@ -193,18 +194,24 @@ private boolean isRelaxKeyValidation() {
return JwtBuildUtils.getConfigProperty(JwtBuildUtils.ENC_KEY_RELAX_VALIDATION_PROPERTY, Boolean.class, false);
}

private String getKeyEncryptionAlgorithm(Key keyEncryptionKey) {
String alg = (String) headers.get("alg");
private String getConfiguredKeyEncryptionAlgorithm() {
String alg = (String) headers.get(HeaderParameterNames.ALGORITHM);
if (alg == null) {
try {
alg = JwtBuildUtils.getConfigProperty(JwtBuildUtils.NEW_TOKEN_KEY_ENCRYPTION_ALG_PROPERTY, String.class);
if (alg != null) {
alg = KeyEncryptionAlgorithm.fromAlgorithm(alg).getAlgorithm();
headers.put(HeaderParameterNames.ALGORITHM, alg);
}
} catch (Exception ex) {
throw ImplMessages.msg.unsupportedKeyEncryptionAlgorithm(alg);
}
}
return alg;
}

private String getKeyEncryptionAlgorithm(Key keyEncryptionKey) {
String alg = getConfiguredKeyEncryptionAlgorithm();

if (keyEncryptionKey instanceof RSAPublicKey) {
if (alg == null) {
Expand Down Expand Up @@ -233,7 +240,7 @@ private static boolean isXecPublicKey(Key encKey) {
}

private String getContentEncryptionAlgorithm() {
String alg = (String) headers.get("enc");
String alg = (String) headers.get(HeaderParameterNames.ENCRYPTION_METHOD);
if (alg == null) {
try {
alg = JwtBuildUtils.getConfigProperty(JwtBuildUtils.NEW_TOKEN_CONTENT_ENCRYPTION_ALG_PROPERTY, String.class);
Expand All @@ -256,34 +263,33 @@ private static String getKeyContentFromLocation(String keyLocation) {
}

Key getEncryptionKeyFromKeyContent(String keyContent) {
String kid = (String) headers.get("kid");
String algHeader = (String) headers.get("alg");
String kid = (String) headers.get(HeaderParameterNames.KEY_ID);
String alg = getConfiguredKeyEncryptionAlgorithm();

// Try PEM format first - default to RSA_OAEP_256 if no algorithm header is set
Key key = KeyUtils.tryAsPemEncryptionPublicKey(keyContent,
(algHeader == null ? KeyEncryptionAlgorithm.RSA_OAEP_256
: KeyEncryptionAlgorithm.fromAlgorithm(algHeader)));
(alg == null ? KeyEncryptionAlgorithm.RSA_OAEP_256 : KeyEncryptionAlgorithm.fromAlgorithm(alg)));
if (key == null) {
if (kid == null) {
kid = JwtBuildUtils.getConfigProperty(JwtBuildUtils.ENC_KEY_ID_PROPERTY, String.class);
if (kid != null) {
headers.put("kid", kid);
headers.put(HeaderParameterNames.KEY_ID, kid);
}
}
// Try to load JWK from a single JWK resource or JWK set resource
JsonWebKey jwk = KeyUtils.getJwkKeyFromJwkSet(kid, keyContent);
if (jwk != null) {
// if the user has already set the algorithm header then JWK `alg` header, if set, must match it
key = KeyUtils.getPublicOrSecretEncryptingKey(jwk,
(algHeader == null ? null : KeyEncryptionAlgorithm.fromAlgorithm(algHeader)));
(alg == null ? null : KeyEncryptionAlgorithm.fromAlgorithm(alg)));
if (key != null) {
// if the algorithm header is not set then use JWK `alg`
if (algHeader == null && jwk.getAlgorithm() != null) {
headers.put("alg", jwk.getAlgorithm());
if (alg == null && jwk.getAlgorithm() != null) {
headers.put(HeaderParameterNames.ALGORITHM, jwk.getAlgorithm());
}
// if 'kid' header is not set then use JWK `kid`
if (kid == null && jwk.getKeyId() != null) {
headers.put("kid", jwk.getKeyId());
headers.put(HeaderParameterNames.KEY_ID, jwk.getKeyId());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.jose4j.jwk.JsonWebKey;
import org.jose4j.jws.JsonWebSignature;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwx.HeaderParameterNames;

import io.smallrye.jwt.algorithm.SignatureAlgorithm;
import io.smallrye.jwt.build.JwtEncryptionBuilder;
Expand Down Expand Up @@ -169,8 +170,8 @@ private String signInternal(Key signingKey) {
for (Map.Entry<String, Object> entry : headers.entrySet()) {
jws.setHeader(entry.getKey(), entry.getValue());
}
if (!headers.containsKey("typ")) {
jws.setHeader("typ", "JWT");
if (!headers.containsKey(HeaderParameterNames.TYPE)) {
jws.setHeader(HeaderParameterNames.TYPE, "JWT");
}

String algorithm = getSignatureAlgorithm(signingKey);
Expand All @@ -193,18 +194,24 @@ private boolean isRelaxKeyValidation() {
return JwtBuildUtils.getConfigProperty(JwtBuildUtils.SIGN_KEY_RELAX_VALIDATION_PROPERTY, Boolean.class, false);
}

private String getSignatureAlgorithm(Key signingKey) {
String alg = (String) headers.get("alg");
private String getConfiguredSignatureAlgorithm() {
String alg = (String) headers.get(HeaderParameterNames.ALGORITHM);
if (alg == null) {
try {
alg = JwtBuildUtils.getConfigProperty(JwtBuildUtils.NEW_TOKEN_SIGNATURE_ALG_PROPERTY, String.class);
if (alg != null) {
alg = SignatureAlgorithm.valueOf(alg.toUpperCase()).getAlgorithm();
headers.put(HeaderParameterNames.ALGORITHM, alg);
}
} catch (Exception ex) {
throw ImplMessages.msg.unsupportedSignatureAlgorithm(alg);
}
}
return alg;
}

private String getSignatureAlgorithm(Key signingKey) {
String alg = getConfiguredSignatureAlgorithm();
if ("none".equals(alg)) {
throw ImplMessages.msg.noneSignatureAlgorithmUnsupported();
}
Expand Down Expand Up @@ -255,17 +262,17 @@ static String getKeyContentFromLocation(String keyLocation) {
}

Key getSigningKeyFromKeyContent(String keyContent) {
String kid = (String) headers.get("kid");
String algHeader = (String) headers.get("alg");
String kid = (String) headers.get(HeaderParameterNames.KEY_ID);
String alg = getConfiguredSignatureAlgorithm();

// Try PEM format first - default to RS256 if no algorithm header is set
// Try PEM format first - default to RS256 if the algorithm is unknown
Key key = KeyUtils.tryAsPemSigningPrivateKey(keyContent,
(algHeader == null ? SignatureAlgorithm.RS256 : SignatureAlgorithm.fromAlgorithm(algHeader)));
(alg == null ? SignatureAlgorithm.RS256 : SignatureAlgorithm.fromAlgorithm(alg)));
if (key == null) {
if (kid == null) {
kid = JwtBuildUtils.getConfigProperty(JwtBuildUtils.SIGN_KEY_ID_PROPERTY, String.class);
if (kid != null) {
headers.put("kid", kid);
headers.put(HeaderParameterNames.KEY_ID, kid);
}
}

Expand All @@ -274,15 +281,15 @@ Key getSigningKeyFromKeyContent(String keyContent) {
if (jwk != null) {
// if the user has already set the algorithm header then JWK `alg` header, if set, must match it
key = KeyUtils.getPrivateOrSecretSigningKey(jwk,
(algHeader == null ? null : SignatureAlgorithm.fromAlgorithm(algHeader)));
(alg == null ? null : SignatureAlgorithm.fromAlgorithm(alg)));
if (key != null) {
// if the algorithm header is not set then use JWK `alg`
if (algHeader == null && jwk.getAlgorithm() != null) {
headers.put("alg", jwk.getAlgorithm());
if (alg == null && jwk.getAlgorithm() != null) {
headers.put(HeaderParameterNames.ALGORITHM, jwk.getAlgorithm());
}
// if 'kid' header is not set then use JWK `kid`
if (kid == null && jwk.getKeyId() != null) {
headers.put("kid", jwk.getKeyId());
headers.put(HeaderParameterNames.KEY_ID, jwk.getKeyId());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ public class JwtBuildConfigSource implements ConfigSource {
JwtBuildUtils.NEW_TOKEN_AUDIENCE_PROPERTY,
JwtBuildUtils.NEW_TOKEN_LIFESPAN_PROPERTY,
JwtBuildUtils.NEW_TOKEN_OVERRIDE_CLAIMS_PROPERTY,
JwtBuildUtils.NEW_TOKEN_SIGNATURE_ALG_PROPERTY,
JwtBuildUtils.NEW_TOKEN_KEY_ENCRYPTION_ALG_PROPERTY,
JwtBuildUtils.SIGN_KEYSTORE_KEY_ALIAS,
JwtBuildUtils.ENC_KEYSTORE_KEY_ALIAS));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,32 @@ void encryptWithConfiguredEcKeyAndA128CBCHS256() throws Exception {
checkJwtClaims(claims);
}

@Test
void encryptWithConfiguredEcKeyAndAlgorithmAndA128CBCHS256() throws Exception {
JwtBuildConfigSource configSource = JwtSignTest.getConfigSource();
configSource.setEncryptionKeyLocation("/ecPublicKey.pem");
configSource.setKeyEncryptionAlgorithm("ECDH-ES+A256KW");
String jweCompact = null;
try {
jweCompact = Jwt.claims()
.claim("customClaim", "custom-value")
.jwe()
.keyId("key-enc-key-id")
.contentAlgorithm(ContentEncryptionAlgorithm.A128CBC_HS256)
.encrypt();
} finally {
configSource.setEncryptionKeyLocation("/publicKey.pem");
configSource.setKeyEncryptionAlgorithm(null);
}

checkJweHeaders(jweCompact, "ECDH-ES+A256KW", "A128CBC-HS256", 4);

JsonWebEncryption jwe = getJsonWebEncryption(jweCompact, getEcPrivateKey());

JwtClaims claims = JwtClaims.parse(jwe.getPlaintextString());
checkJwtClaims(claims);
}

@Test
void encryptWithConfiguredEcKeyAndContentAlgorithm() throws Exception {
JwtBuildConfigSource configSource = JwtSignTest.getConfigSource();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,30 @@ void signClaimsEcKey() throws Exception {
assertEquals("custom-value", claims.getClaimValue("customClaim"));
}

@Test
void signClaimsEcKeyFileWithConfiguredAlgorithm() throws Exception {
JwtBuildConfigSource configSource = getConfigSource();
configSource.setSigningKeyLocation("/ecPrivateKey.pem");
configSource.setSignatureAlgorithm(SignatureAlgorithm.ES256.getAlgorithm());
String jwt = null;
try {
jwt = Jwt.claim("customClaim", "custom-value")
.sign();
} finally {
configSource.setSigningKeyLocation("/privateKey.pem");
configSource.setSignatureAlgorithm(null);
}

PublicKey ecKey = getEcPublicKey();
JsonWebSignature jws = getVerifiedJws(jwt, ecKey);
JwtClaims claims = JwtClaims.parse(jws.getPayload());

assertEquals(4, claims.getClaimsMap().size());
Map<String, Object> headers = getJwsHeaders(jwt, 2);
checkDefaultClaimsAndHeaders(headers, claims, "ES256", 300);
assertEquals("custom-value", claims.getClaimValue("customClaim"));
}

private static SecretKey createSecretKey() throws Exception {
String jwkJson = "{\"kty\":\"oct\",\"k\":\"Fdh9u8rINxfivbrianbbVT1u232VQBZYKx1HGAGPt2I\"}";
JsonWebKey jwk = JsonWebKey.Factory.newJwk(jwkJson);
Expand Down

0 comments on commit b00c216

Please sign in to comment.