Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support custom TimeProvider when validating tokens (introspect, userinfo) #730

Merged
merged 2 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,11 @@ package no.nav.security.mock.oauth2.introspect
import com.fasterxml.jackson.annotation.JsonInclude
import com.fasterxml.jackson.annotation.JsonProperty
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.SignedJWT
import com.nimbusds.oauth2.sdk.OAuth2Error
import com.nimbusds.oauth2.sdk.id.Issuer
import mu.KotlinLogging
import no.nav.security.mock.oauth2.OAuth2Exception
import no.nav.security.mock.oauth2.extensions.OAuth2Endpoints.INTROSPECT
import no.nav.security.mock.oauth2.extensions.issuerId
import no.nav.security.mock.oauth2.extensions.toIssuerUrl
import no.nav.security.mock.oauth2.extensions.verifySignatureAndIssuer
import no.nav.security.mock.oauth2.http.OAuth2HttpRequest
import no.nav.security.mock.oauth2.http.Route
import no.nav.security.mock.oauth2.http.json
Expand Down Expand Up @@ -51,12 +47,10 @@ internal fun Route.Builder.introspect(tokenProvider: OAuth2TokenProvider) =
}

private fun OAuth2HttpRequest.verifyToken(tokenProvider: OAuth2TokenProvider): JWTClaimsSet? {
val tokenString = this.formParameters.get("token")
val issuer = url.toIssuerUrl()
val jwkSet = tokenProvider.publicJwkSet(issuer.issuerId())
val algorithm = tokenProvider.getAlgorithm()
return try {
SignedJWT.parse(tokenString).verifySignatureAndIssuer(Issuer(issuer.toString()), jwkSet, algorithm)
this.formParameters.get("token")?.let {
tokenProvider.verify(url.toIssuerUrl(), it)
}
} catch (e: Exception) {
log.debug("token_introspection: failed signature validation")
return null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package no.nav.security.mock.oauth2.token
import com.nimbusds.jose.JWSAlgorithm
import com.nimbusds.jose.jwk.ECKey
import com.nimbusds.jose.jwk.JWK
import com.nimbusds.jose.jwk.JWKSelector
import com.nimbusds.jose.jwk.JWKSet
import com.nimbusds.jose.jwk.KeyType
import com.nimbusds.jose.jwk.RSAKey
import com.nimbusds.jose.jwk.source.JWKSource
import com.nimbusds.jose.proc.SecurityContext
import no.nav.security.mock.oauth2.OAuth2Exception
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.LinkedBlockingDeque
Expand All @@ -15,7 +18,7 @@ open class KeyProvider
constructor(
private val initialKeys: List<JWK> = keysFromFile(INITIAL_KEYS_FILE),
private val algorithm: String = JWSAlgorithm.RS256.name,
) {
) : JWKSource<SecurityContext> {
private val signingKeys: ConcurrentHashMap<String, JWK> = ConcurrentHashMap()

private var generator: KeyGenerator = KeyGenerator(JWSAlgorithm.parse(algorithm))
Expand All @@ -35,9 +38,11 @@ open class KeyProvider
KeyType.RSA.value -> {
RSAKey.Builder(polledJwk.toRSAKey()).keyID(keyId).build()
}

KeyType.EC.value -> {
ECKey.Builder(polledJwk.toECKey()).keyID(keyId).build()
}

else -> {
throw OAuth2Exception("Unsupported key type: ${polledJwk.keyType.value}")
}
Expand All @@ -63,4 +68,10 @@ open class KeyProvider
return emptyList()
}
}

override fun get(
jwkSelector: JWKSelector?,
context: SecurityContext?,
): MutableList<JWK> = jwkSelector?.select(JWKSet(signingKeys.values.toList()).toPublicJWKSet()) ?: mutableListOf()

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@ import com.nimbusds.jose.crypto.ECDSASigner
import com.nimbusds.jose.crypto.RSASSASigner
import com.nimbusds.jose.jwk.JWKSet
import com.nimbusds.jose.jwk.KeyType
import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier
import com.nimbusds.jose.proc.JWSVerificationKeySelector
import com.nimbusds.jose.proc.SecurityContext
import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.SignedJWT
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier
import com.nimbusds.jwt.proc.DefaultJWTProcessor
import com.nimbusds.oauth2.sdk.TokenRequest
import no.nav.security.mock.oauth2.OAuth2Exception
import no.nav.security.mock.oauth2.extensions.clientIdAsString
Expand Down Expand Up @@ -106,6 +111,11 @@ class OAuth2TokenProvider
builder.build()
}.sign(issuerId, JOSEObjectType.JWT.type)

fun verify(
issuerUrl: HttpUrl,
token: String,
): JWTClaimsSet = SignedJWT.parse(token).verify(issuerUrl)

private fun JWTClaimsSet.sign(
issuerId: String,
type: String,
Expand All @@ -124,6 +134,7 @@ class OAuth2TokenProvider
sign(RSASSASigner(key.toRSAKey().toPrivateKey()))
}
}

supported && keyType == KeyType.EC.value -> {
SignedJWT(
jwsHeader(key.keyID, type, algorithm),
Expand All @@ -132,6 +143,7 @@ class OAuth2TokenProvider
sign(ECDSASigner(key.toECKey().toECPrivateKey()))
}
}

else -> {
throw OAuth2Exception("Unsupported algorithm: ${algorithm.name}")
}
Expand Down Expand Up @@ -178,4 +190,20 @@ class OAuth2TokenProvider
}

private fun Instant?.orNow(): Instant = this ?: Instant.now()

private fun SignedJWT.verify(issuerUrl: HttpUrl): JWTClaimsSet {
val jwtProcessor =
DefaultJWTProcessor<SecurityContext?>().apply {
jwsTypeVerifier = DefaultJOSEObjectTypeVerifier(JOSEObjectType("JWT"))
jwsKeySelector = JWSVerificationKeySelector(keyProvider.algorithm(), keyProvider)
jwtClaimsSetVerifier =
object : DefaultJWTClaimsVerifier<SecurityContext?>(
JWTClaimsSet.Builder().issuer(issuerUrl.toString()).build(),
HashSet(listOf("iat", "exp")),
) {
override fun currentTime(): Date = Date.from(timeProvider().orNow())
}
}
return jwtProcessor.process(this, null)
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
package no.nav.security.mock.oauth2.userinfo

import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.SignedJWT
import com.nimbusds.oauth2.sdk.ErrorObject
import com.nimbusds.oauth2.sdk.http.HTTPResponse
import com.nimbusds.oauth2.sdk.id.Issuer
import mu.KotlinLogging
import no.nav.security.mock.oauth2.OAuth2Exception
import no.nav.security.mock.oauth2.extensions.OAuth2Endpoints.USER_INFO
import no.nav.security.mock.oauth2.extensions.issuerId
import no.nav.security.mock.oauth2.extensions.toIssuerUrl
import no.nav.security.mock.oauth2.extensions.verifySignatureAndIssuer
import no.nav.security.mock.oauth2.http.OAuth2HttpRequest
import no.nav.security.mock.oauth2.http.Route
import no.nav.security.mock.oauth2.http.json
Expand All @@ -26,17 +22,12 @@ internal fun Route.Builder.userInfo(tokenProvider: OAuth2TokenProvider) =
json(claims)
}

private fun OAuth2HttpRequest.verifyBearerToken(tokenProvider: OAuth2TokenProvider): JWTClaimsSet {
val tokenString = this.headers.bearerToken()
val issuer = url.toIssuerUrl()
val jwkSet = tokenProvider.publicJwkSet(issuer.issuerId())
val algorithm = tokenProvider.getAlgorithm()
return try {
SignedJWT.parse(tokenString).verifySignatureAndIssuer(Issuer(issuer.toString()), jwkSet, algorithm)
private fun OAuth2HttpRequest.verifyBearerToken(tokenProvider: OAuth2TokenProvider): JWTClaimsSet =
try {
tokenProvider.verify(url.toIssuerUrl(), this.headers.bearerToken())
} catch (e: Exception) {
throw invalidToken(e.message ?: "could not verify bearer token")
}
}

private fun Headers.bearerToken(): String =
this["Authorization"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import no.nav.security.mock.oauth2.token.OAuth2TokenProvider
import okhttp3.Headers
import okhttp3.HttpUrl.Companion.toHttpUrl
import org.junit.jupiter.api.Test
import java.time.Instant
import java.time.temporal.ChronoUnit

internal class IntrospectTest {
private val rs384TokenProvider = OAuth2TokenProvider(keyProvider = KeyProvider(initialKeys = emptyList(), algorithm = JWSAlgorithm.RS384.name))
Expand Down Expand Up @@ -66,6 +68,29 @@ internal class IntrospectTest {
}
}

@Test
fun `introspect should return active and claims from token when using a custom timeProvider in the OAuth2TokenProvider`() {
val issuerUrl = "http://localhost/default"
val yesterday = Instant.now().minus(1, ChronoUnit.DAYS)
val tokenProvider = OAuth2TokenProvider(timeProvider = { yesterday })
val claims =
mapOf(
"iss" to issuerUrl,
"client_id" to "yolo",
"token_type" to "token",
"sub" to "foo",
)
val token = tokenProvider.jwt(claims)
val request = request("$issuerUrl$INTROSPECT", token.serialize())

routes { introspect(tokenProvider) }.invoke(request).asClue {
it.status shouldBe 200
val response = it.parse<Map<String, Any>>()
response shouldContainAll claims
response shouldContain ("active" to true)
}
}

@Test
fun `introspect should return active false when token is missing`() {
val url = "http://localhost/default$INTROSPECT"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ import okhttp3.HttpUrl.Companion.toHttpUrl
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
import java.time.Clock
import java.time.Instant
import java.time.ZoneId
import java.time.temporal.ChronoUnit
import java.util.Date

Expand Down Expand Up @@ -106,87 +104,71 @@ internal class OAuth2TokenProviderRSATest {
val yesterday = Instant.now().minus(1, ChronoUnit.DAYS)
val tokenProvider = OAuth2TokenProvider(systemTime = yesterday)

tokenProvider
.exchangeAccessToken(
tokenRequest =
nimbusTokenRequest(
"id",
"grant_type" to GrantType.CLIENT_CREDENTIALS.value,
"scope" to "scope1",
),
issuerUrl = "http://default_if_not_overridden".toHttpUrl(),
claimsSet = tokenProvider.jwt(mapOf()).jwtClaimsSet,
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
).asClue {
it.jwtClaimsSet.issueTime shouldBe Date.from(tokenProvider.systemTime)
println(it.serialize())
}
tokenProvider.clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBe Date.from(tokenProvider.systemTime)
}

val now = Instant.now().minus(1, ChronoUnit.SECONDS)
OAuth2TokenProvider().clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBeAfter now
}
}

@Test
fun `token should have issuedAt set dynamically according to timeProvider`() {
val clock =
object : Clock() {
private var clock = systemDefaultZone()
val timeProvider =
object : TimeProvider {
var time = Instant.now()

override fun instant() = clock.instant()

override fun withZone(zone: ZoneId) = clock.withZone(zone)

override fun getZone() = clock.zone

fun fixed(instant: Instant) {
clock = fixed(instant, zone)
}
override fun invoke(): Instant = time
}

val tokenProvider = OAuth2TokenProvider { clock.instant() }
val tokenProvider = OAuth2TokenProvider(timeProvider = timeProvider)

val instant1 = Instant.parse("2000-12-03T10:15:30.00Z")
val instant2 = Instant.parse("2020-01-21T00:00:00.00Z")
instant1 shouldNotBe instant2

run {
clock.fixed(instant1)
tokenProvider.systemTime shouldBe instant1
timeProvider.time = instant1
tokenProvider.systemTime shouldBe instant1

tokenProvider.exchangeAccessToken(
tokenRequest =
nimbusTokenRequest(
"id",
"grant_type" to GrantType.CLIENT_CREDENTIALS.value,
"scope" to "scope1",
),
issuerUrl = "http://default_if_not_overridden".toHttpUrl(),
claimsSet = tokenProvider.jwt(mapOf()).jwtClaimsSet,
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)
}.asClue {
tokenProvider.clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBe Date.from(instant1)
println(it.serialize())
}

run {
clock.fixed(instant2)
tokenProvider.systemTime shouldBe instant2
timeProvider.time = instant2
tokenProvider.systemTime shouldBe instant2

tokenProvider.exchangeAccessToken(
tokenRequest =
nimbusTokenRequest(
"id",
"grant_type" to GrantType.CLIENT_CREDENTIALS.value,
"scope" to "scope1",
),
issuerUrl = "http://default_if_not_overridden".toHttpUrl(),
claimsSet = tokenProvider.jwt(mapOf()).jwtClaimsSet,
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)
}.asClue {
tokenProvider.clientCredentialsToken("http://localhost/default").asClue {
it.jwtClaimsSet.issueTime shouldBe Date.from(instant2)
println(it.serialize())
}
}

@Test
fun `token with issueTime set to yesterday should be able to validate with the verify function using the same timeprovider`() {
val yesterday = Instant.now().minus(1, ChronoUnit.DAYS)
val tokenProvider = OAuth2TokenProvider(timeProvider = { yesterday })

val token = tokenProvider.clientCredentialsToken("http://localhost/default")

token.jwtClaimsSet.issueTime shouldBe Date.from(tokenProvider.systemTime)

tokenProvider.verify("http://localhost/default".toHttpUrl(), token.serialize()).toJSONObject().asClue {
it shouldBe token.jwtClaimsSet.toJSONObject()
}
}

private fun OAuth2TokenProvider.clientCredentialsToken(issuerUrl: String): SignedJWT =
accessToken(
tokenRequest =
nimbusTokenRequest(
"client1",
"grant_type" to "client_credentials",
"scope" to "scope1",
),
issuerUrl = issuerUrl.toHttpUrl(),
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)

private fun idToken(issuerUrl: String): SignedJWT =
tokenProvider.idToken(
tokenRequest =
Expand All @@ -198,4 +180,6 @@ internal class OAuth2TokenProviderRSATest {
issuerUrl = issuerUrl.toHttpUrl(),
oAuth2TokenCallback = DefaultOAuth2TokenCallback(),
)

private infix fun Date.shouldBeAfter(instant: Instant?) = this.after(Date.from(instant)) shouldBe true
}
Loading