Skip to content

Commit

Permalink
feat: support custom TimeProvider when validating tokens
Browse files Browse the repository at this point in the history
* add verify function to OAuth2TokenProvider and use the TimeProvider if set - i.e. via overriding Nimbus DefaultJWTClaimsVerifier's currentTime function
* refactor tests for simplicity
  • Loading branch information
tommytroen committed Aug 20, 2024
1 parent 4aab3a5 commit 36ca7a8
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 83 deletions.
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ dependencies {
implementation("com.fasterxml.jackson.module:jackson-module-kotlin:$jacksonVersion")
implementation("org.freemarker:freemarker:$freemarkerVersion")
implementation("org.bouncycastle:bcpkix-jdk18on:$bouncyCastleVersion")
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.6.3")
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.7.0")
testImplementation("org.assertj:assertj-core:$assertjVersion")
testImplementation("org.junit.jupiter:junit-jupiter-api:$junitJupiterVersion")
testImplementation("org.junit.jupiter:junit-jupiter-params:$junitJupiterVersion")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,11 @@ package no.nav.security.mock.oauth2.introspect
import com.fasterxml.jackson.annotation.JsonInclude
import com.fasterxml.jackson.annotation.JsonProperty
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.SignedJWT
import com.nimbusds.oauth2.sdk.OAuth2Error
import com.nimbusds.oauth2.sdk.id.Issuer
import mu.KotlinLogging
import no.nav.security.mock.oauth2.OAuth2Exception
import no.nav.security.mock.oauth2.extensions.OAuth2Endpoints.INTROSPECT
import no.nav.security.mock.oauth2.extensions.issuerId
import no.nav.security.mock.oauth2.extensions.toIssuerUrl
import no.nav.security.mock.oauth2.extensions.verifySignatureAndIssuer
import no.nav.security.mock.oauth2.http.OAuth2HttpRequest
import no.nav.security.mock.oauth2.http.Route
import no.nav.security.mock.oauth2.http.json
Expand Down Expand Up @@ -51,12 +47,10 @@ internal fun Route.Builder.introspect(tokenProvider: OAuth2TokenProvider) =
}

private fun OAuth2HttpRequest.verifyToken(tokenProvider: OAuth2TokenProvider): JWTClaimsSet? {
val tokenString = this.formParameters.get("token")
val issuer = url.toIssuerUrl()
val jwkSet = tokenProvider.publicJwkSet(issuer.issuerId())
val algorithm = tokenProvider.getAlgorithm()
return try {
SignedJWT.parse(tokenString).verifySignatureAndIssuer(Issuer(issuer.toString()), jwkSet, algorithm)
this.formParameters.get("token")?.let {
tokenProvider.verify(url.toIssuerUrl(), it)
}
} catch (e: Exception) {
log.debug("token_introspection: failed signature validation")
return null
Expand Down
14 changes: 13 additions & 1 deletion src/main/kotlin/no/nav/security/mock/oauth2/token/KeyProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package no.nav.security.mock.oauth2.token
import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.jwk.ECKey
import com.nimbusds.jose.jwk.JWK
import com.nimbusds.jose.jwk.JWKSelector
import com.nimbusds.jose.jwk.JWKSet
import com.nimbusds.jose.jwk.KeyType
import com.nimbusds.jose.jwk.RSAKey
import com.nimbusds.jose.jwk.source.JWKSource
import com.nimbusds.jose.proc.SecurityContext
import no.nav.security.mock.oauth2.OAuth2Exception
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.LinkedBlockingDeque
Expand All @@ -15,7 +18,7 @@ open class KeyProvider
constructor(
private val initialKeys: List<JWK> = keysFromFile(INITIAL_KEYS_FILE),
private val algorithm: String = JWSAlgorithm.RS256.name,
) {
) : JWKSource<SecurityContext> {
private val signingKeys: ConcurrentHashMap<String, JWK> = ConcurrentHashMap()

private var generator: KeyGenerator = KeyGenerator(JWSAlgorithm.parse(algorithm))
Expand All @@ -35,9 +38,11 @@ open class KeyProvider
KeyType.RSA.value -> {
RSAKey.Builder(polledJwk.toRSAKey()).keyID(keyId).build()
}

KeyType.EC.value -> {
ECKey.Builder(polledJwk.toECKey()).keyID(keyId).build()
}

else -> {
throw OAuth2Exception("Unsupported key type: ${polledJwk.keyType.value}")
}
Expand All @@ -63,4 +68,11 @@ open class KeyProvider
return emptyList()
}
}

override fun get(
jwkSelector: JWKSelector?,
context: SecurityContext?,
): MutableList<JWK> {
return signingKeys.values.toMutableList()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@ import com.nimbusds.jose.crypto.ECDSASigner
import com.nimbusds.jose.crypto.RSASSASigner
import com.nimbusds.jose.jwk.JWKSet
import com.nimbusds.jose.jwk.KeyType
import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier
import com.nimbusds.jose.proc.JWSVerificationKeySelector
import com.nimbusds.jose.proc.SecurityContext
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.SignedJWT
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier
import com.nimbusds.jwt.proc.DefaultJWTProcessor
import com.nimbusds.oauth2.sdk.TokenRequest
import no.nav.security.mock.oauth2.OAuth2Exception
import no.nav.security.mock.oauth2.extensions.clientIdAsString
Expand Down Expand Up @@ -107,6 +112,13 @@ class OAuth2TokenProvider
builder.build()
}.sign(issuerId, JOSEObjectType.JWT.type)

fun verify(
issuerUrl: HttpUrl,
token: String,
): JWTClaimsSet {
return SignedJWT.parse(token).verify(issuerUrl)
}

private fun JWTClaimsSet.sign(
issuerId: String,
type: String,
Expand All @@ -125,6 +137,7 @@ class OAuth2TokenProvider
sign(RSASSASigner(key.toRSAKey().toPrivateKey()))
}
}

supported && keyType == KeyType.EC.value -> {
SignedJWT(
jwsHeader(key.keyID, type, algorithm),
Expand All @@ -133,6 +146,7 @@ class OAuth2TokenProvider
sign(ECDSASigner(key.toECKey().toECPrivateKey()))
}
}

else -> {
throw OAuth2Exception("Unsupported algorithm: ${algorithm.name}")
}
Expand Down Expand Up @@ -177,4 +191,22 @@ class OAuth2TokenProvider
}

private fun Instant?.orNow(): Instant = this ?: Instant.now()

private fun SignedJWT.verify(issuerUrl: HttpUrl): JWTClaimsSet {
val jwtProcessor =
DefaultJWTProcessor<SecurityContext?>().apply {
jwsTypeVerifier = DefaultJOSEObjectTypeVerifier(JOSEObjectType("JWT"))
jwsKeySelector = JWSVerificationKeySelector(keyProvider.algorithm(), keyProvider)
jwtClaimsSetVerifier =
object : DefaultJWTClaimsVerifier<SecurityContext?>(
JWTClaimsSet.Builder().issuer(issuerUrl.toString()).build(),
HashSet(listOf("iat", "exp")),
) {
override fun currentTime(): Date {
return Date.from(timeProvider())
}
}
}
return jwtProcessor.process(this, null)
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
package no.nav.security.mock.oauth2.userinfo

import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.SignedJWT
import com.nimbusds.oauth2.sdk.ErrorObject
import com.nimbusds.oauth2.sdk.http.HTTPResponse
import com.nimbusds.oauth2.sdk.id.Issuer
import mu.KotlinLogging
import no.nav.security.mock.oauth2.OAuth2Exception
import no.nav.security.mock.oauth2.extensions.OAuth2Endpoints.USER_INFO
import no.nav.security.mock.oauth2.extensions.issuerId
import no.nav.security.mock.oauth2.extensions.toIssuerUrl
import no.nav.security.mock.oauth2.extensions.verifySignatureAndIssuer
import no.nav.security.mock.oauth2.http.OAuth2HttpRequest
import no.nav.security.mock.oauth2.http.Route
import no.nav.security.mock.oauth2.http.json
Expand All @@ -27,12 +23,8 @@ internal fun Route.Builder.userInfo(tokenProvider: OAuth2TokenProvider) =
}

private fun OAuth2HttpRequest.verifyBearerToken(tokenProvider: OAuth2TokenProvider): JWTClaimsSet {
val tokenString = this.headers.bearerToken()
val issuer = url.toIssuerUrl()
val jwkSet = tokenProvider.publicJwkSet(issuer.issuerId())
val algorithm = tokenProvider.getAlgorithm()
return try {
SignedJWT.parse(tokenString).verifySignatureAndIssuer(Issuer(issuer.toString()), jwkSet, algorithm)
tokenProvider.verify(url.toIssuerUrl(), this.headers.bearerToken())
} catch (e: Exception) {
throw invalidToken(e.message ?: "could not verify bearer token")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import no.nav.security.mock.oauth2.token.OAuth2TokenProvider
import okhttp3.Headers
import okhttp3.HttpUrl.Companion.toHttpUrl
import org.junit.jupiter.api.Test
import java.time.Instant
import java.time.temporal.ChronoUnit

internal class IntrospectTest {
private val rs384TokenProvider = OAuth2TokenProvider(keyProvider = KeyProvider(initialKeys = emptyList(), algorithm = JWSAlgorithm.RS384.name))
Expand Down Expand Up @@ -66,6 +68,29 @@ internal class IntrospectTest {
}
}

@Test
fun `introspect should return active and claims from token when using a custom timeProvider in the OAuth2TokenProvider`() {
val issuerUrl = "http://localhost/default"
val yesterday = Instant.now().minus(1, ChronoUnit.DAYS)
val tokenProvider = OAuth2TokenProvider(timeProvider = { yesterday })
val claims =
mapOf(
"iss" to issuerUrl,
"client_id" to "yolo",
"token_type" to "token",
"sub" to "foo",
)
val token = tokenProvider.jwt(claims)
val request = request("$issuerUrl$INTROSPECT", token.serialize())

routes { introspect(tokenProvider) }.invoke(request).asClue {
it.status shouldBe 200
val response = it.parse<Map<String, Any>>()
response shouldContainAll claims
response shouldContain ("active" to true)
}
}

@Test
fun `introspect should return active false when token is missing`() {
val url = "http://localhost/default$INTROSPECT"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ import okhttp3.HttpUrl.Companion.toHttpUrl
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
import java.time.Clock
import java.time.Instant
import java.time.ZoneId
import java.time.temporal.ChronoUnit
import java.util.Date

Expand Down Expand Up @@ -104,86 +102,71 @@ internal class OAuth2TokenProviderRSATest {
val yesterday = Instant.now().minus(1, ChronoUnit.DAYS)
val tokenProvider = OAuth2TokenProvider(systemTime = yesterday)

tokenProvider.exchangeAccessToken(
tokenRequest =
nimbusTokenRequest(
"id",
"grant_type" to GrantType.CLIENT_CREDENTIALS.value,
"scope" to "scope1",
),
issuerUrl = "http://default_if_not_overridden".toHttpUrl(),
claimsSet = tokenProvider.jwt(mapOf()).jwtClaimsSet,
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
).asClue {
tokenProvider.clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBe Date.from(tokenProvider.systemTime)
println(it.serialize())
}

val now = Instant.now()
OAuth2TokenProvider().clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBeAfter now
}
}

@Test
fun `token should have issuedAt set dynamically according to timeProvider`() {
val clock =
object : Clock() {
private var clock = systemDefaultZone()
val timeProvider =
object : TimeProvider {
var time = Instant.now()

override fun instant() = clock.instant()

override fun withZone(zone: ZoneId) = clock.withZone(zone)

override fun getZone() = clock.zone

fun fixed(instant: Instant) {
clock = fixed(instant, zone)
}
override fun invoke(): Instant = time
}

val tokenProvider = OAuth2TokenProvider { clock.instant() }
val tokenProvider = OAuth2TokenProvider(timeProvider = timeProvider)

val instant1 = Instant.parse("2000-12-03T10:15:30.00Z")
val instant2 = Instant.parse("2020-01-21T00:00:00.00Z")
instant1 shouldNotBe instant2

run {
clock.fixed(instant1)
tokenProvider.systemTime shouldBe instant1

tokenProvider.exchangeAccessToken(
tokenRequest =
nimbusTokenRequest(
"id",
"grant_type" to GrantType.CLIENT_CREDENTIALS.value,
"scope" to "scope1",
),
issuerUrl = "http://default_if_not_overridden".toHttpUrl(),
claimsSet = tokenProvider.jwt(mapOf()).jwtClaimsSet,
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)
}.asClue {

timeProvider.time = instant1
tokenProvider.systemTime shouldBe instant1

tokenProvider.clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBe Date.from(instant1)
println(it.serialize())
}

run {
clock.fixed(instant2)
tokenProvider.systemTime shouldBe instant2

tokenProvider.exchangeAccessToken(
tokenRequest =
nimbusTokenRequest(
"id",
"grant_type" to GrantType.CLIENT_CREDENTIALS.value,
"scope" to "scope1",
),
issuerUrl = "http://default_if_not_overridden".toHttpUrl(),
claimsSet = tokenProvider.jwt(mapOf()).jwtClaimsSet,
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)
}.asClue {
timeProvider.time = instant2
tokenProvider.systemTime shouldBe instant2

tokenProvider.clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBe Date.from(instant2)
println(it.serialize())
}
}

@Test
fun `token with issueTime set to yesterday should be able to validate with the verify function using the same timeprovider`() {
val yesterday = Instant.now().minus(1, ChronoUnit.DAYS)
val tokenProvider = OAuth2TokenProvider(timeProvider = { yesterday })

val token = tokenProvider.clientCredentialsToken("http://localhost/default")

token.jwtClaimsSet.issueTime shouldBe Date.from(tokenProvider.systemTime)

tokenProvider.verify("http://localhost/default".toHttpUrl(), token.serialize()).toJSONObject().asClue {
it shouldBe token.jwtClaimsSet.toJSONObject()
}
}

private fun OAuth2TokenProvider.clientCredentialsToken(issuerUrl: String): SignedJWT =
accessToken(
tokenRequest =
nimbusTokenRequest(
"client1",
"grant_type" to "client_credentials",
"scope" to "scope1",
),
issuerUrl = issuerUrl.toHttpUrl(),
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)

private fun idToken(issuerUrl: String): SignedJWT =
tokenProvider.idToken(
tokenRequest =
Expand All @@ -195,4 +178,6 @@ internal class OAuth2TokenProviderRSATest {
issuerUrl = issuerUrl.toHttpUrl(),
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)

private infix fun Date.shouldBeAfter(instant: Instant?) = this.after(Date.from(instant)) shouldBe true
}
Loading

0 comments on commit 36ca7a8

Please sign in to comment.