Skip to content

Commit

Permalink
Tokenexchange errorhandling (#31)
Browse files Browse the repository at this point in the history
*  tokendings errorhandling
* azureexchange errorhandling

Closes #29
  • Loading branch information
rannveigskjerve authored May 6, 2024
1 parent 2633080 commit ffa5885
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 57 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
.gradle
build
.idea
Original file line number Diff line number Diff line change
Expand Up @@ -8,43 +8,50 @@ import no.nav.tms.token.support.azure.exchange.consumer.AzureConsumer


class NonCachingAzureService internal constructor(
private val azureConsumer: AzureConsumer,
issuer: String,
clientId: String,
privateJwk: String
): AzureService {
private val azureConsumer: AzureConsumer,
issuer: String,
clientId: String,
privateJwk: String
) : AzureService {

private val clientAssertionService = ClientAssertionService(privateJwk, clientId, issuer)

override suspend fun getAccessToken(targetApp: String): String {
override suspend fun getAccessToken(targetApp: String): String = try {
val jwt = clientAssertionService.createClientAssertion()

return azureConsumer.fetchToken(jwt, targetApp).accessToken
azureConsumer.fetchToken(jwt, targetApp).accessToken
} catch (throwable: Throwable) {
throw AzureExchangeException(throwable, targetApp)
}


}

class CachingAzureService internal constructor(
private val azureConsumer: AzureConsumer,
issuer: String,
clientId: String,
privateJwk: String,
maxCacheEntries: Long,
cacheExpiryMarginSeconds: Int,
): AzureService {
private val azureConsumer: AzureConsumer,
issuer: String,
clientId: String,
privateJwk: String,
maxCacheEntries: Long,
cacheExpiryMarginSeconds: Int,
) : AzureService {

private val cache = CacheBuilder.buildCache(maxCacheEntries, cacheExpiryMarginSeconds)

private val clientAssertionService = ClientAssertionService(privateJwk, clientId, issuer)


override suspend fun getAccessToken(targetApp: String): String {
override suspend fun getAccessToken(targetApp: String): String =
try {
cache.get(targetApp) {
runBlocking {
performTokenExchange(targetApp)
}
}.accessToken
} catch (throwable: Throwable) {
throw AzureExchangeException(throwable, targetApp)
}

return cache.get(targetApp) {
runBlocking {
performTokenExchange(targetApp)
}
}.accessToken
}

private suspend fun performTokenExchange(targetApp: String): AccessTokenEntry {
val jwt = clientAssertionService.createClientAssertion()
Expand All @@ -54,3 +61,16 @@ class CachingAzureService internal constructor(
return AccessTokenEntry.fromResponse(response)
}
}

class AzureExchangeException(val originalThrowable: Throwable, targetApp: String) :
Exception() {

val stackTraceSummary =
originalThrowable.stackTrace.firstOrNull()?.let { stacktraceElement ->
""" Azureexchange feiler for $targetApp
Origin: ${stacktraceElement.fileName ?: "---"} ${stacktraceElement.methodName ?: "----"} linenumber:${stacktraceElement.lineNumber}
Message: "${originalThrowable::class.simpleName} ${originalThrowable.message?.let { ":$it" }}"
""".trimIndent()
} ?: "${originalThrowable::class.simpleName} ${originalThrowable.message?.let { ":$it" }}"
}

Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ import io.kotest.matchers.shouldNotBe
import io.mockk.*
import kotlinx.coroutines.runBlocking
import no.nav.tms.token.support.azure.exchange.consumer.AzureConsumer
import no.nav.tms.token.support.azure.exchange.service.AzureExchangeException
import no.nav.tms.token.support.azure.exchange.service.CachingAzureService
import no.nav.tms.token.support.azure.exchange.service.NonCachingAzureService
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import java.net.SocketTimeoutException

internal class AzureServiceTest {

Expand Down Expand Up @@ -91,7 +94,7 @@ internal class AzureServiceTest {
cachingAzureService.getAccessToken(target)
}

coVerify(exactly = 1) {azureConsumer.fetchToken(any(), target) }
coVerify(exactly = 1) { azureConsumer.fetchToken(any(), target) }
}

@Test
Expand All @@ -110,7 +113,7 @@ internal class AzureServiceTest {
cachingAzureService.getAccessToken(target)
}

coVerify(exactly = 3) {azureConsumer.fetchToken(any(), target) }
coVerify(exactly = 3) { azureConsumer.fetchToken(any(), target) }
}

@Test
Expand All @@ -134,12 +137,55 @@ internal class AzureServiceTest {
val result3 = runBlocking { cachingAzureService.getAccessToken(target1) }
val result4 = runBlocking { cachingAzureService.getAccessToken(target2) }

coVerify(exactly = 1) {azureConsumer.fetchToken(any(), target1) }
coVerify(exactly = 1) {azureConsumer.fetchToken(any(), target2) }
coVerify(exactly = 1) { azureConsumer.fetchToken(any(), target1) }
coVerify(exactly = 1) { azureConsumer.fetchToken(any(), target2) }

result1 shouldBe result3
result2 shouldBe result4
result1 shouldNotBe result2
result3 shouldNotBe result4
}

@Test
fun `Should throw AzureExchangeException if exchangeprocess fails`() {
assertNonCachingServiceThrows { IllegalArgumentException() }
assertNonCachingServiceThrows { SocketTimeoutException() }
assertNonCachingServiceThrows { Error() }
assertCachingServiceThrows { IllegalArgumentException() }
assertCachingServiceThrows { SocketTimeoutException() }
assertCachingServiceThrows { Error() }

}

fun assertNonCachingServiceThrows( throwable: () -> Throwable) = run {
NonCachingAzureService(
azureConsumer = mockk<AzureConsumer>().apply {
coEvery { fetchToken(any(), any()) } throws throwable()
},
clientId = "some:client",
issuer = "some:issuer",
privateJwk = privateJwk,
).apply {
assertThrows<AzureExchangeException> { runBlocking { getAccessToken("appappapp") } }
}
}

fun assertCachingServiceThrows(throwable: () -> Throwable) = run {
CachingAzureService(
azureConsumer = mockk<AzureConsumer>().apply {
coEvery { fetchToken(any(), any()) } throws throwable()
},
clientId = "some:client",
privateJwk = privateJwk,
issuer = "some:issuer",
maxCacheEntries = 1,
cacheExpiryMarginSeconds = 6

).apply {
assertThrows<AzureExchangeException> { runBlocking { getAccessToken("token") } }
}
}


}

Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ internal class TokendingsConsumer(
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,51 @@ import no.nav.tms.token.support.tokendings.exchange.consumer.TokendingsConsumer
import no.nav.tms.token.support.tokendings.exchange.service.ClientAssertion.createSignedAssertion

class NonCachingTokendingsService internal constructor(
private val tokendingsConsumer: TokendingsConsumer,
private val jwtAudience: String,
private val clientId: String,
privateJwk: String
): TokendingsService {
private val tokendingsConsumer: TokendingsConsumer,
private val jwtAudience: String,
private val clientId: String,
privateJwk: String
) : TokendingsService {

private val privateRsaKey = RSAKey.parse(privateJwk)

override suspend fun exchangeToken(token: String, targetApp: String): String {
val jwt = createSignedAssertion(clientId, jwtAudience, privateRsaKey)
try {
val jwt = createSignedAssertion(clientId, jwtAudience, privateRsaKey)

return tokendingsConsumer.exchangeToken(token, jwt, targetApp).accessToken
} catch (throwable: Throwable) {
throw TokendingsExchangeException(throwable, clientId)
}

return tokendingsConsumer.exchangeToken(token, jwt, targetApp).accessToken
}
}

class CachingTokendingsService internal constructor(
private val tokendingsConsumer: TokendingsConsumer,
private val jwtAudience: String,
private val clientId: String,
privateJwk: String,
maxCacheEntries: Long,
cacheExpiryMarginSeconds: Int,
): TokendingsService {
private val tokendingsConsumer: TokendingsConsumer,
private val jwtAudience: String,
private val clientId: String,
privateJwk: String,
maxCacheEntries: Long,
cacheExpiryMarginSeconds: Int,
) : TokendingsService {

private val cache = CacheBuilder.buildCache(maxCacheEntries, cacheExpiryMarginSeconds)

private val privateRsaKey = RSAKey.parse(privateJwk)

override suspend fun exchangeToken(token: String, targetApp: String): String {
val cacheKey = TokenStringUtil.createCacheKey(token, targetApp)

return cache.get(cacheKey) {
runBlocking {
performTokenExchange(token, targetApp)
}
}.accessToken
try {
val cacheKey = TokenStringUtil.createCacheKey(token, targetApp)

return cache.get(cacheKey) {
runBlocking {
performTokenExchange(token, targetApp)
}
}.accessToken
} catch (throwable: Throwable) {
throw TokendingsExchangeException(throwable, clientId)
}
}

private suspend fun performTokenExchange(token: String, targetApp: String): AccessTokenEntry {
Expand All @@ -68,3 +77,16 @@ internal object TokenStringUtil {
return AccessTokenKey(subject, securityLevel, targetApp)
}
}

class TokendingsExchangeException(val originalThrowable: Throwable, clientId: String) :
Exception() {

val stackTraceSummary =
originalThrowable.stackTrace.firstOrNull()?.let { stacktraceElement ->
""" Tokendingsexchange feiler for $clientId
Origin: ${stacktraceElement.fileName ?: "---"} ${stacktraceElement.methodName ?: "----"} linenumber:${stacktraceElement.lineNumber}
Message: "${originalThrowable::class.simpleName} ${originalThrowable.message?.let { ":$it" }}"
""".trimIndent()
} ?: "${originalThrowable::class.simpleName} ${originalThrowable.message?.let { ":$it" }}"
}

Loading

0 comments on commit ffa5885

Please sign in to comment.