From fe55b15102a43a032337f6fb892e0412b3f8759a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tommy=20Tr=C3=B8en?= Date: Tue, 20 Aug 2024 09:51:15 +0200 Subject: [PATCH] feat: support custom TimeProvider when validating tokens (introspect, userinfo) * add verify function to OAuth2TokenProvider and use the TimeProvider if set - i.e. via overriding Nimbus DefaultJWTClaimsVerifier's currentTime function * refactor tests for simplicity --- .../mock/oauth2/introspect/Introspect.kt | 12 +- .../security/mock/oauth2/token/KeyProvider.kt | 89 ++--- .../mock/oauth2/token/OAuth2TokenProvider.kt | 318 ++++++++++-------- .../security/mock/oauth2/userinfo/UserInfo.kt | 10 +- .../mock/oauth2/introspect/IntrospectTest.kt | 25 ++ .../token/OAuth2TokenProviderRSATest.kt | 154 ++++----- .../mock/oauth2/userinfo/UserInfoTest.kt | 22 ++ 7 files changed, 343 insertions(+), 287 deletions(-) diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/introspect/Introspect.kt b/src/main/kotlin/no/nav/security/mock/oauth2/introspect/Introspect.kt index 49b79a757..46dcdba00 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/introspect/Introspect.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/introspect/Introspect.kt @@ -3,15 +3,11 @@ package no.nav.security.mock.oauth2.introspect import com.fasterxml.jackson.annotation.JsonInclude import com.fasterxml.jackson.annotation.JsonProperty import com.nimbusds.jwt.JWTClaimsSet -import com.nimbusds.jwt.SignedJWT import com.nimbusds.oauth2.sdk.OAuth2Error -import com.nimbusds.oauth2.sdk.id.Issuer import mu.KotlinLogging import no.nav.security.mock.oauth2.OAuth2Exception import no.nav.security.mock.oauth2.extensions.OAuth2Endpoints.INTROSPECT -import no.nav.security.mock.oauth2.extensions.issuerId import no.nav.security.mock.oauth2.extensions.toIssuerUrl -import no.nav.security.mock.oauth2.extensions.verifySignatureAndIssuer import no.nav.security.mock.oauth2.http.OAuth2HttpRequest import no.nav.security.mock.oauth2.http.Route import no.nav.security.mock.oauth2.http.json @@ -51,12 +47,10 @@ internal fun Route.Builder.introspect(tokenProvider: OAuth2TokenProvider) = } private fun OAuth2HttpRequest.verifyToken(tokenProvider: OAuth2TokenProvider): JWTClaimsSet? { - val tokenString = this.formParameters.get("token") - val issuer = url.toIssuerUrl() - val jwkSet = tokenProvider.publicJwkSet(issuer.issuerId()) - val algorithm = tokenProvider.getAlgorithm() return try { - SignedJWT.parse(tokenString).verifySignatureAndIssuer(Issuer(issuer.toString()), jwkSet, algorithm) + this.formParameters.get("token")?.let { + tokenProvider.verify(url.toIssuerUrl(), it) + } } catch (e: Exception) { log.debug("token_introspection: failed signature validation") return null diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/token/KeyProvider.kt b/src/main/kotlin/no/nav/security/mock/oauth2/token/KeyProvider.kt index 8157de175..a9f6e8278 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/token/KeyProvider.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/token/KeyProvider.kt @@ -3,64 +3,73 @@ package no.nav.security.mock.oauth2.token import com.nimbusds.jose.JWSAlgorithm import com.nimbusds.jose.jwk.ECKey import com.nimbusds.jose.jwk.JWK +import com.nimbusds.jose.jwk.JWKSelector import com.nimbusds.jose.jwk.JWKSet import com.nimbusds.jose.jwk.KeyType import com.nimbusds.jose.jwk.RSAKey -import no.nav.security.mock.oauth2.OAuth2Exception +import com.nimbusds.jose.jwk.source.JWKSource +import com.nimbusds.jose.proc.SecurityContext import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.LinkedBlockingDeque +import no.nav.security.mock.oauth2.OAuth2Exception open class KeyProvider - @JvmOverloads - constructor( - private val initialKeys: List = keysFromFile(INITIAL_KEYS_FILE), - private val algorithm: String = JWSAlgorithm.RS256.name, - ) { - private val signingKeys: ConcurrentHashMap = ConcurrentHashMap() +@JvmOverloads +constructor( + private val initialKeys: List = keysFromFile(INITIAL_KEYS_FILE), + private val algorithm: String = JWSAlgorithm.RS256.name, +) : JWKSource { + private val signingKeys: ConcurrentHashMap = ConcurrentHashMap() - private var generator: KeyGenerator = KeyGenerator(JWSAlgorithm.parse(algorithm)) + private var generator: KeyGenerator = KeyGenerator(JWSAlgorithm.parse(algorithm)) - private val keyDeque = - LinkedBlockingDeque().apply { - initialKeys.forEach { - put(it) - } + private val keyDeque = + LinkedBlockingDeque().apply { + initialKeys.forEach { + put(it) } + } + + fun signingKey(keyId: String): JWK = signingKeys.computeIfAbsent(keyId) { keyFromDequeOrNew(keyId) } + + private fun keyFromDequeOrNew(keyId: String): JWK = + keyDeque.poll()?.let { polledJwk -> + when (polledJwk.keyType.value) { + KeyType.RSA.value -> { + RSAKey.Builder(polledJwk.toRSAKey()).keyID(keyId).build() + } - fun signingKey(keyId: String): JWK = signingKeys.computeIfAbsent(keyId) { keyFromDequeOrNew(keyId) } + KeyType.EC.value -> { + ECKey.Builder(polledJwk.toECKey()).keyID(keyId).build() + } - private fun keyFromDequeOrNew(keyId: String): JWK = - keyDeque.poll()?.let { polledJwk -> - when (polledJwk.keyType.value) { - KeyType.RSA.value -> { - RSAKey.Builder(polledJwk.toRSAKey()).keyID(keyId).build() - } - KeyType.EC.value -> { - ECKey.Builder(polledJwk.toECKey()).keyID(keyId).build() - } - else -> { - throw OAuth2Exception("Unsupported key type: ${polledJwk.keyType.value}") - } + else -> { + throw OAuth2Exception("Unsupported key type: ${polledJwk.keyType.value}") } - } ?: generator.generateKey(keyId) + } + } ?: generator.generateKey(keyId) - fun algorithm(): JWSAlgorithm = JWSAlgorithm.parse(algorithm) + fun algorithm(): JWSAlgorithm = JWSAlgorithm.parse(algorithm) - fun keyType(): String = generator.keyGenerator.algorithm + fun keyType(): String = generator.keyGenerator.algorithm - fun generate(algorithm: String) { - generator = KeyGenerator(JWSAlgorithm.parse(algorithm)) - } + fun generate(algorithm: String) { + generator = KeyGenerator(JWSAlgorithm.parse(algorithm)) + } - companion object { - const val INITIAL_KEYS_FILE = "/mock-oauth2-server-keys.json" + companion object { + const val INITIAL_KEYS_FILE = "/mock-oauth2-server-keys.json" - fun keysFromFile(filename: String): List { - val keysFromFile = KeyProvider::class.java.getResource(filename) - if (keysFromFile != null) { - return JWKSet.parse(keysFromFile.readText()).keys.map { it as JWK } - } - return emptyList() + fun keysFromFile(filename: String): List { + val keysFromFile = KeyProvider::class.java.getResource(filename) + if (keysFromFile != null) { + return JWKSet.parse(keysFromFile.readText()).keys.map { it as JWK } } + return emptyList() } } + + override fun get(jwkSelector: JWKSelector?, context: SecurityContext?): MutableList { + return signingKeys.values.toMutableList() + } +} diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt index b3a67390f..a0d89ab97 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt @@ -7,174 +7,204 @@ import com.nimbusds.jose.crypto.ECDSASigner import com.nimbusds.jose.crypto.RSASSASigner import com.nimbusds.jose.jwk.JWKSet import com.nimbusds.jose.jwk.KeyType +import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier +import com.nimbusds.jose.proc.JWSVerificationKeySelector +import com.nimbusds.jose.proc.SecurityContext import com.nimbusds.jwt.JWTClaimsSet import com.nimbusds.jwt.SignedJWT +import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier +import com.nimbusds.jwt.proc.DefaultJWTProcessor import com.nimbusds.oauth2.sdk.TokenRequest -import no.nav.security.mock.oauth2.OAuth2Exception -import no.nav.security.mock.oauth2.extensions.clientIdAsString -import no.nav.security.mock.oauth2.extensions.issuerId -import okhttp3.HttpUrl import java.time.Duration import java.time.Instant import java.util.Date import java.util.UUID +import no.nav.security.mock.oauth2.OAuth2Exception +import no.nav.security.mock.oauth2.extensions.clientIdAsString +import no.nav.security.mock.oauth2.extensions.issuerId +import okhttp3.HttpUrl typealias TimeProvider = () -> Instant? class OAuth2TokenProvider +@JvmOverloads +constructor( + private val keyProvider: KeyProvider = KeyProvider(), + private val timeProvider: TimeProvider, +) { + val systemTime + get() = timeProvider() + @JvmOverloads constructor( - private val keyProvider: KeyProvider = KeyProvider(), - private val timeProvider: TimeProvider, - ) { - val systemTime - get() = timeProvider() - - @JvmOverloads - constructor( - keyProvider: KeyProvider = KeyProvider(), - systemTime: Instant? = null, - ) : this(keyProvider, { systemTime }) - - @JvmOverloads - fun publicJwkSet(issuerId: String = "default"): JWKSet { - return JWKSet(keyProvider.signingKey(issuerId)).toPublicJWKSet() - } + keyProvider: KeyProvider = KeyProvider(), + systemTime: Instant? = null, + ) : this(keyProvider, { systemTime }) - fun getAlgorithm(): JWSAlgorithm { - return keyProvider.algorithm() - } + @JvmOverloads + fun publicJwkSet(issuerId: String = "default"): JWKSet { + return JWKSet(keyProvider.signingKey(issuerId)).toPublicJWKSet() + } - fun idToken( - tokenRequest: TokenRequest, - issuerUrl: HttpUrl, - oAuth2TokenCallback: OAuth2TokenCallback, - nonce: String? = null, - ) = defaultClaims( - issuerUrl, - oAuth2TokenCallback.subject(tokenRequest), - listOf(tokenRequest.clientIdAsString()), - nonce, - oAuth2TokenCallback.addClaims(tokenRequest), - oAuth2TokenCallback.tokenExpiry(), - ).sign(issuerUrl.issuerId(), oAuth2TokenCallback.typeHeader(tokenRequest)) - - fun accessToken( - tokenRequest: TokenRequest, - issuerUrl: HttpUrl, - oAuth2TokenCallback: OAuth2TokenCallback, - nonce: String? = null, - ) = defaultClaims( - issuerUrl, - oAuth2TokenCallback.subject(tokenRequest), - oAuth2TokenCallback.audience(tokenRequest), - nonce, - oAuth2TokenCallback.addClaims(tokenRequest), - oAuth2TokenCallback.tokenExpiry(), - ).sign(issuerUrl.issuerId(), oAuth2TokenCallback.typeHeader(tokenRequest)) - - fun exchangeAccessToken( - tokenRequest: TokenRequest, - issuerUrl: HttpUrl, - claimsSet: JWTClaimsSet, - oAuth2TokenCallback: OAuth2TokenCallback, - ) = systemTime.orNow().let { now -> - JWTClaimsSet.Builder(claimsSet) - .issuer(issuerUrl.toString()) - .expirationTime(Date.from(now.plusSeconds(oAuth2TokenCallback.tokenExpiry()))) - .notBeforeTime(Date.from(now)) + fun getAlgorithm(): JWSAlgorithm { + return keyProvider.algorithm() + } + + fun idToken( + tokenRequest: TokenRequest, + issuerUrl: HttpUrl, + oAuth2TokenCallback: OAuth2TokenCallback, + nonce: String? = null, + ) = defaultClaims( + issuerUrl, + oAuth2TokenCallback.subject(tokenRequest), + listOf(tokenRequest.clientIdAsString()), + nonce, + oAuth2TokenCallback.addClaims(tokenRequest), + oAuth2TokenCallback.tokenExpiry(), + ).sign(issuerUrl.issuerId(), oAuth2TokenCallback.typeHeader(tokenRequest)) + + fun accessToken( + tokenRequest: TokenRequest, + issuerUrl: HttpUrl, + oAuth2TokenCallback: OAuth2TokenCallback, + nonce: String? = null, + ) = defaultClaims( + issuerUrl, + oAuth2TokenCallback.subject(tokenRequest), + oAuth2TokenCallback.audience(tokenRequest), + nonce, + oAuth2TokenCallback.addClaims(tokenRequest), + oAuth2TokenCallback.tokenExpiry(), + ).sign(issuerUrl.issuerId(), oAuth2TokenCallback.typeHeader(tokenRequest)) + + fun exchangeAccessToken( + tokenRequest: TokenRequest, + issuerUrl: HttpUrl, + claimsSet: JWTClaimsSet, + oAuth2TokenCallback: OAuth2TokenCallback, + ) = systemTime.orNow().let { now -> + JWTClaimsSet.Builder(claimsSet) + .issuer(issuerUrl.toString()) + .expirationTime(Date.from(now.plusSeconds(oAuth2TokenCallback.tokenExpiry()))) + .notBeforeTime(Date.from(now)) + .issueTime(Date.from(now)) + .jwtID(UUID.randomUUID().toString()) + .audience(oAuth2TokenCallback.audience(tokenRequest)) + .addClaims(oAuth2TokenCallback.addClaims(tokenRequest)) + .build() + .sign(issuerUrl.issuerId(), oAuth2TokenCallback.typeHeader(tokenRequest)) + } + + @JvmOverloads + fun jwt( + claims: Map, + expiry: Duration = Duration.ofHours(1), + issuerId: String = "default", + ): SignedJWT = + JWTClaimsSet.Builder().let { builder -> + val now = systemTime.orNow() + builder .issueTime(Date.from(now)) - .jwtID(UUID.randomUUID().toString()) - .audience(oAuth2TokenCallback.audience(tokenRequest)) - .addClaims(oAuth2TokenCallback.addClaims(tokenRequest)) - .build() - .sign(issuerUrl.issuerId(), oAuth2TokenCallback.typeHeader(tokenRequest)) - } + .notBeforeTime(Date.from(now)) + .expirationTime(Date.from(now.plusSeconds(expiry.toSeconds()))) + builder.addClaims(claims) + builder.build() + }.sign(issuerId, JOSEObjectType.JWT.type) - @JvmOverloads - fun jwt( - claims: Map, - expiry: Duration = Duration.ofHours(1), - issuerId: String = "default", - ): SignedJWT = - JWTClaimsSet.Builder().let { builder -> - val now = systemTime.orNow() - builder - .issueTime(Date.from(now)) - .notBeforeTime(Date.from(now)) - .expirationTime(Date.from(now.plusSeconds(expiry.toSeconds()))) - builder.addClaims(claims) - builder.build() - }.sign(issuerId, JOSEObjectType.JWT.type) - - private fun JWTClaimsSet.sign( - issuerId: String, - type: String, - ): SignedJWT { - val key = keyProvider.signingKey(issuerId) - val algorithm = keyProvider.algorithm() - val keyType = keyProvider.keyType() - val supported = KeyGenerator.isSupported(algorithm) - - return when { - supported && keyType == KeyType.RSA.value -> { - SignedJWT( - jwsHeader(key.keyID, type, algorithm), - this, - ).apply { - sign(RSASSASigner(key.toRSAKey().toPrivateKey())) - } - } - supported && keyType == KeyType.EC.value -> { - SignedJWT( - jwsHeader(key.keyID, type, algorithm), - this, - ).apply { - sign(ECDSASigner(key.toECKey().toECPrivateKey())) - } - } - else -> { - throw OAuth2Exception("Unsupported algorithm: ${algorithm.name}") + fun verify( + issuerUrl: HttpUrl, + token: String, + ): JWTClaimsSet { + return SignedJWT.parse(token).verify(issuerUrl) + } + + private fun JWTClaimsSet.sign( + issuerId: String, + type: String, + ): SignedJWT { + val key = keyProvider.signingKey(issuerId) + val algorithm = keyProvider.algorithm() + val keyType = keyProvider.keyType() + val supported = KeyGenerator.isSupported(algorithm) + + return when { + supported && keyType == KeyType.RSA.value -> { + SignedJWT( + jwsHeader(key.keyID, type, algorithm), + this, + ).apply { + sign(RSASSASigner(key.toRSAKey().toPrivateKey())) } } - } - private fun jwsHeader( - keyId: String, - type: String, - algorithm: JWSAlgorithm, - ): JWSHeader { - return JWSHeader.Builder(algorithm) - .keyID(keyId) - .type(JOSEObjectType(type)).build() - } + supported && keyType == KeyType.EC.value -> { + SignedJWT( + jwsHeader(key.keyID, type, algorithm), + this, + ).apply { + sign(ECDSASigner(key.toECKey().toECPrivateKey())) + } + } - private fun JWTClaimsSet.Builder.addClaims(claims: Map = emptyMap()) = - apply { - claims.forEach { this.claim(it.key, it.value) } + else -> { + throw OAuth2Exception("Unsupported algorithm: ${algorithm.name}") } + } + } - private fun defaultClaims( - issuerUrl: HttpUrl, - subject: String?, - audience: List, - nonce: String?, - additionalClaims: Map, - expiry: Long, - ) = JWTClaimsSet.Builder().let { builder -> - val now = systemTime.orNow() - builder.subject(subject) - .audience(audience) - .issuer(issuerUrl.toString()) - .issueTime(Date.from(now)) - .notBeforeTime(Date.from(now)) - .expirationTime(Date.from(now.plusSeconds(expiry))) - .jwtID(UUID.randomUUID().toString()) + private fun jwsHeader( + keyId: String, + type: String, + algorithm: JWSAlgorithm, + ): JWSHeader { + return JWSHeader.Builder(algorithm) + .keyID(keyId) + .type(JOSEObjectType(type)).build() + } - nonce?.also { builder.claim("nonce", it) } - builder.addClaims(additionalClaims) - builder.build() + private fun JWTClaimsSet.Builder.addClaims(claims: Map = emptyMap()) = + apply { + claims.forEach { this.claim(it.key, it.value) } } - private fun Instant?.orNow(): Instant = this ?: Instant.now() + private fun defaultClaims( + issuerUrl: HttpUrl, + subject: String?, + audience: List, + nonce: String?, + additionalClaims: Map, + expiry: Long, + ) = JWTClaimsSet.Builder().let { builder -> + val now = systemTime.orNow() + builder.subject(subject) + .audience(audience) + .issuer(issuerUrl.toString()) + .issueTime(Date.from(now)) + .notBeforeTime(Date.from(now)) + .expirationTime(Date.from(now.plusSeconds(expiry))) + .jwtID(UUID.randomUUID().toString()) + + nonce?.also { builder.claim("nonce", it) } + builder.addClaims(additionalClaims) + builder.build() + } + + private fun Instant?.orNow(): Instant = this ?: Instant.now() + + private fun SignedJWT.verify(issuerUrl: HttpUrl): JWTClaimsSet { + val jwtProcessor = DefaultJWTProcessor().apply { + jwsTypeVerifier = DefaultJOSEObjectTypeVerifier(JOSEObjectType("JWT")) + jwsKeySelector = JWSVerificationKeySelector(keyProvider.algorithm(), keyProvider) + jwtClaimsSetVerifier = object : DefaultJWTClaimsVerifier( + JWTClaimsSet.Builder().issuer(issuerUrl.toString()).build(), + HashSet(listOf("iat", "exp")), + ) { + override fun currentTime(): Date { + return Date.from(timeProvider()) + } + } + } + return jwtProcessor.process(this, null) } +} diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/userinfo/UserInfo.kt b/src/main/kotlin/no/nav/security/mock/oauth2/userinfo/UserInfo.kt index d7cce6eda..343a057e7 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/userinfo/UserInfo.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/userinfo/UserInfo.kt @@ -1,16 +1,12 @@ package no.nav.security.mock.oauth2.userinfo import com.nimbusds.jwt.JWTClaimsSet -import com.nimbusds.jwt.SignedJWT import com.nimbusds.oauth2.sdk.ErrorObject import com.nimbusds.oauth2.sdk.http.HTTPResponse -import com.nimbusds.oauth2.sdk.id.Issuer import mu.KotlinLogging import no.nav.security.mock.oauth2.OAuth2Exception import no.nav.security.mock.oauth2.extensions.OAuth2Endpoints.USER_INFO -import no.nav.security.mock.oauth2.extensions.issuerId import no.nav.security.mock.oauth2.extensions.toIssuerUrl -import no.nav.security.mock.oauth2.extensions.verifySignatureAndIssuer import no.nav.security.mock.oauth2.http.OAuth2HttpRequest import no.nav.security.mock.oauth2.http.Route import no.nav.security.mock.oauth2.http.json @@ -27,12 +23,8 @@ internal fun Route.Builder.userInfo(tokenProvider: OAuth2TokenProvider) = } private fun OAuth2HttpRequest.verifyBearerToken(tokenProvider: OAuth2TokenProvider): JWTClaimsSet { - val tokenString = this.headers.bearerToken() - val issuer = url.toIssuerUrl() - val jwkSet = tokenProvider.publicJwkSet(issuer.issuerId()) - val algorithm = tokenProvider.getAlgorithm() return try { - SignedJWT.parse(tokenString).verifySignatureAndIssuer(Issuer(issuer.toString()), jwkSet, algorithm) + tokenProvider.verify(url.toIssuerUrl(), this.headers.bearerToken()) } catch (e: Exception) { throw invalidToken(e.message ?: "could not verify bearer token") } diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/introspect/IntrospectTest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/introspect/IntrospectTest.kt index dc288a720..0112d8988 100644 --- a/src/test/kotlin/no/nav/security/mock/oauth2/introspect/IntrospectTest.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/introspect/IntrospectTest.kt @@ -9,6 +9,8 @@ import io.kotest.matchers.maps.shouldContain import io.kotest.matchers.maps.shouldContainAll import io.kotest.matchers.maps.shouldContainExactly import io.kotest.matchers.shouldBe +import java.time.Instant +import java.time.temporal.ChronoUnit import no.nav.security.mock.oauth2.OAuth2Exception import no.nav.security.mock.oauth2.extensions.OAuth2Endpoints.INTROSPECT import no.nav.security.mock.oauth2.http.OAuth2HttpRequest @@ -66,6 +68,29 @@ internal class IntrospectTest { } } + @Test + fun `introspect should return active and claims from token when using a custom timeProvider in the OAuth2TokenProvider`() { + val issuerUrl = "http://localhost/default" + val yesterday = Instant.now().minus(1, ChronoUnit.DAYS) + val tokenProvider = OAuth2TokenProvider(timeProvider = { yesterday }) + val claims = + mapOf( + "iss" to issuerUrl, + "client_id" to "yolo", + "token_type" to "token", + "sub" to "foo", + ) + val token = tokenProvider.jwt(claims) + val request = request("$issuerUrl$INTROSPECT", token.serialize()) + + routes { introspect(tokenProvider) }.invoke(request).asClue { + it.status shouldBe 200 + val response = it.parse>() + response shouldContainAll claims + response shouldContain ("active" to true) + } + } + @Test fun `introspect should return active false when token is missing`() { val url = "http://localhost/default$INTROSPECT" diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProviderRSATest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProviderRSATest.kt index 9235f2493..c2e4aa5b3 100644 --- a/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProviderRSATest.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProviderRSATest.kt @@ -10,17 +10,15 @@ import io.kotest.assertions.asClue import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe import io.kotest.matchers.shouldNotBe +import java.time.Instant +import java.time.temporal.ChronoUnit +import java.util.Date import no.nav.security.mock.oauth2.extensions.verifySignatureAndIssuer import no.nav.security.mock.oauth2.testutils.nimbusTokenRequest import okhttp3.HttpUrl.Companion.toHttpUrl import org.junit.jupiter.api.Test import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ValueSource -import java.time.Clock -import java.time.Instant -import java.time.ZoneId -import java.time.temporal.ChronoUnit -import java.util.Date internal class OAuth2TokenProviderRSATest { private val tokenProvider = OAuth2TokenProvider() @@ -55,22 +53,22 @@ internal class OAuth2TokenProviderRSATest { tokenProvider.exchangeAccessToken( tokenRequest = - nimbusTokenRequest( - "myclient", - "grant_type" to GrantType.JWT_BEARER.value, - "scope" to "scope1", - "assertion" to initialToken.serialize(), - ), + nimbusTokenRequest( + "myclient", + "grant_type" to GrantType.JWT_BEARER.value, + "scope" to "scope1", + "assertion" to initialToken.serialize(), + ), issuerUrl = "http://default_if_not_overridden".toHttpUrl(), claimsSet = initialToken.jwtClaimsSet, oAuth2TokenCallback = - DefaultOAuth2TokenCallback( - claims = - mapOf( - "extraclaim" to "extra", - "iss" to "http://overrideissuer", - ), + DefaultOAuth2TokenCallback( + claims = + mapOf( + "extraclaim" to "extra", + "iss" to "http://overrideissuer", ), + ), ).jwtClaimsSet.asClue { it.issuer shouldBe "http://overrideissuer" it.subject shouldBe "initialsubject" @@ -104,95 +102,81 @@ internal class OAuth2TokenProviderRSATest { val yesterday = Instant.now().minus(1, ChronoUnit.DAYS) val tokenProvider = OAuth2TokenProvider(systemTime = yesterday) - tokenProvider.exchangeAccessToken( - tokenRequest = - nimbusTokenRequest( - "id", - "grant_type" to GrantType.CLIENT_CREDENTIALS.value, - "scope" to "scope1", - ), - issuerUrl = "http://default_if_not_overridden".toHttpUrl(), - claimsSet = tokenProvider.jwt(mapOf()).jwtClaimsSet, - oAuth2TokenCallback = DefaultOAuth2TokenCallback(), - ).asClue { + tokenProvider.clientCredentialsToken("http://localhost/default").asClue { it.jwtClaimsSet.issueTime shouldBe Date.from(tokenProvider.systemTime) - println(it.serialize()) + } + + val now = Instant.now() + OAuth2TokenProvider().clientCredentialsToken("http://localhost/default").asClue { + it.jwtClaimsSet.issueTime shouldBeAfter now } } @Test fun `token should have issuedAt set dynamically according to timeProvider`() { - val clock = - object : Clock() { - private var clock = systemDefaultZone() - - override fun instant() = clock.instant() - - override fun withZone(zone: ZoneId) = clock.withZone(zone) - - override fun getZone() = clock.zone - - fun fixed(instant: Instant) { - clock = fixed(instant, zone) - } - } + val timeProvider = object : TimeProvider { + var time = Instant.now() + override fun invoke(): Instant = time + } - val tokenProvider = OAuth2TokenProvider { clock.instant() } + val tokenProvider = OAuth2TokenProvider(timeProvider = timeProvider) val instant1 = Instant.parse("2000-12-03T10:15:30.00Z") val instant2 = Instant.parse("2020-01-21T00:00:00.00Z") - instant1 shouldNotBe instant2 - - run { - clock.fixed(instant1) - tokenProvider.systemTime shouldBe instant1 - - tokenProvider.exchangeAccessToken( - tokenRequest = - nimbusTokenRequest( - "id", - "grant_type" to GrantType.CLIENT_CREDENTIALS.value, - "scope" to "scope1", - ), - issuerUrl = "http://default_if_not_overridden".toHttpUrl(), - claimsSet = tokenProvider.jwt(mapOf()).jwtClaimsSet, - oAuth2TokenCallback = DefaultOAuth2TokenCallback(), - ) - }.asClue { + + timeProvider.time = instant1 + tokenProvider.systemTime shouldBe instant1 + + tokenProvider.clientCredentialsToken("http://localhost/default").asClue { it.jwtClaimsSet.issueTime shouldBe Date.from(instant1) - println(it.serialize()) } - run { - clock.fixed(instant2) - tokenProvider.systemTime shouldBe instant2 - - tokenProvider.exchangeAccessToken( - tokenRequest = - nimbusTokenRequest( - "id", - "grant_type" to GrantType.CLIENT_CREDENTIALS.value, - "scope" to "scope1", - ), - issuerUrl = "http://default_if_not_overridden".toHttpUrl(), - claimsSet = tokenProvider.jwt(mapOf()).jwtClaimsSet, - oAuth2TokenCallback = DefaultOAuth2TokenCallback(), - ) - }.asClue { + timeProvider.time = instant2 + tokenProvider.systemTime shouldBe instant2 + + tokenProvider.clientCredentialsToken("http://localhost/default").asClue { it.jwtClaimsSet.issueTime shouldBe Date.from(instant2) - println(it.serialize()) } } + @Test + fun `token with issueTime set to yesterday should be able to validate with the verify function using the same timeprovider`() { + val yesterday = Instant.now().minus(1, ChronoUnit.DAYS) + val tokenProvider = OAuth2TokenProvider(timeProvider = { yesterday }) + + val token = tokenProvider.clientCredentialsToken("http://localhost/default") + + token.jwtClaimsSet.issueTime shouldBe Date.from(tokenProvider.systemTime) + + tokenProvider.verify("http://localhost/default".toHttpUrl(), token.serialize()).toJSONObject().asClue { + it shouldBe token.jwtClaimsSet.toJSONObject() + } + } + + private fun OAuth2TokenProvider.clientCredentialsToken(issuerUrl: String): SignedJWT = + accessToken( + tokenRequest = + nimbusTokenRequest( + "client1", + "grant_type" to "client_credentials", + "scope" to "scope1", + ), + issuerUrl = issuerUrl.toHttpUrl(), + oAuth2TokenCallback = DefaultOAuth2TokenCallback(), + ) + private fun idToken(issuerUrl: String): SignedJWT = tokenProvider.idToken( tokenRequest = - nimbusTokenRequest( - "client1", - "grant_type" to "authorization_code", - "code" to "123", - ), + nimbusTokenRequest( + "client1", + "grant_type" to "authorization_code", + "code" to "123", + ), issuerUrl = issuerUrl.toHttpUrl(), oAuth2TokenCallback = DefaultOAuth2TokenCallback(), ) + + private infix fun Date.shouldBeAfter(instant: Instant?) = this.after(Date.from(instant)) shouldBe true } + diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/userinfo/UserInfoTest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/userinfo/UserInfoTest.kt index 761f6be02..83c702d5b 100644 --- a/src/test/kotlin/no/nav/security/mock/oauth2/userinfo/UserInfoTest.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/userinfo/UserInfoTest.kt @@ -7,6 +7,8 @@ import io.kotest.assertions.asClue import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.maps.shouldContainAll import io.kotest.matchers.shouldBe +import java.time.Instant +import java.time.temporal.ChronoUnit import no.nav.security.mock.oauth2.OAuth2Exception import no.nav.security.mock.oauth2.extensions.OAuth2Endpoints.USER_INFO import no.nav.security.mock.oauth2.http.OAuth2HttpRequest @@ -38,6 +40,26 @@ internal class UserInfoTest { } } + @Test + fun `userinfo should return claims from bearer token when using a custom timeProvider in OAuth2TokenProvider`() { + val issuerUrl = "http://localhost/default" + val yesterday = Instant.now().minus(1, ChronoUnit.DAYS) + val tokenProvider = OAuth2TokenProvider(timeProvider = { yesterday }) + val claims = + mapOf( + "iss" to issuerUrl, + "sub" to "foo", + "extra" to "bar", + ) + val bearerToken = tokenProvider.jwt(claims) + val request = request("$issuerUrl$USER_INFO", bearerToken.serialize()) + + routes { userInfo(tokenProvider) }.invoke(request).asClue { + it.status shouldBe 200 + it.parse>() shouldContainAll claims + } + } + @Test fun `userinfo should throw OAuth2Exception when algorithm does not match`() { val issuerUrl = "http://localhost/default"