Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(build): netty-all with netty-codec-http #723

Merged
merged 2 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions build.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import java.time.Duration
import com.github.benmanes.gradle.versions.updates.DependencyUpdatesTask
import org.jetbrains.kotlin.gradle.dsl.JvmTarget

val assertjVersion = "3.26.3"
val kotlinLoggingVersion = "3.0.5"
Expand Down Expand Up @@ -61,7 +62,7 @@ dependencies {
implementation("ch.qos.logback:logback-classic:$logbackVersion")
api("com.squareup.okhttp3:mockwebserver:$mockWebServerVersion")
api("com.nimbusds:oauth2-oidc-sdk:$nimbusSdkVersion")
implementation("io.netty:netty-all:$nettyVersion")
implementation("io.netty:netty-codec-http:$nettyVersion")
implementation("io.github.microutils:kotlin-logging:$kotlinLoggingVersion")
implementation("com.fasterxml.jackson.module:jackson-module-kotlin:$jacksonVersion")
implementation("org.freemarker:freemarker:$freemarkerVersion")
Expand Down Expand Up @@ -123,7 +124,7 @@ dependencies {
}

configurations {
all {
all {
resolutionStrategy.force("com.fasterxml.woodstox:woodstox-core:7.0.0")
}
}
Expand Down Expand Up @@ -289,8 +290,8 @@ tasks {
}

withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> {
kotlinOptions {
jvmTarget = JavaVersion.VERSION_17.toString()
compilerOptions {
jvmTarget.set(JvmTarget.JVM_17)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,7 @@ open class MockOAuth2Server(
@Deprecated("Use MockWebServer method/function instead", ReplaceWith("MockWebServer.enqueue()"))
fun enqueueResponse(
@Suppress("UNUSED_PARAMETER") response: MockResponse,
) {
throw UnsupportedOperationException("cannot enqueue MockResponse, please use the MockWebServer directly with QueueDispatcher")
}
): Unit = throw UnsupportedOperationException("cannot enqueue MockResponse, please use the MockWebServer directly with QueueDispatcher")

/**
* Enqueues a callback at the server's HTTP request handler.
Expand Down Expand Up @@ -328,7 +326,8 @@ open class MockOAuth2Server(
}

internal fun Map<String, Any>.toJwtClaimsSet(): JWTClaimsSet =
JWTClaimsSet.Builder()
JWTClaimsSet
.Builder()
.apply {
[email protected] {
this.claim(it.key, it.value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ data class OAuth2Config
}

companion object {
fun fromJson(json: String): OAuth2Config {
return jacksonObjectMapper().readValue(json)
}
fun fromJson(json: String): OAuth2Config = jacksonObjectMapper().readValue(json)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ import com.nimbusds.oauth2.sdk.OAuth2Error
import com.nimbusds.oauth2.sdk.http.HTTPResponse

@Suppress("unused")
class OAuth2Exception(val errorObject: ErrorObject?, msg: String, throwable: Throwable?) :
RuntimeException(msg, throwable) {
class OAuth2Exception(
val errorObject: ErrorObject?,
msg: String,
throwable: Throwable?,
) : RuntimeException(msg, throwable) {
constructor(msg: String) : this(null, msg, null)
constructor(msg: String, throwable: Throwable?) : this(null, msg, throwable)
constructor(errorObject: ErrorObject?, msg: String) : this(errorObject, msg, null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ object StandaloneConfig {
const val PORT = "PORT" // Supports running Docker image on Heroku.

fun hostname(): InetAddress =
SERVER_HOSTNAME.fromEnv()
SERVER_HOSTNAME
.fromEnv()
?.let { InetAddress.getByName(it) } ?: InetSocketAddress(0).address

fun port(): Int = (SERVER_PORT.fromEnv()?.toInt() ?: PORT.fromEnv()?.toInt()) ?: 8080
Expand Down
37 changes: 22 additions & 15 deletions src/main/kotlin/no/nav/security/mock/oauth2/debugger/Client.kt
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ internal class TokenRequest(
"\n\n$body"

private fun Map<String, String>.toKeyValueString(entrySeparator: String): String =
this.map { "${it.key}=${it.value}" }
.toList().joinToString(entrySeparator)
this
.map { "${it.key}=${it.value}" }
.toList()
.joinToString(entrySeparator)
}

internal data class ClientAuthentication(
Expand Down Expand Up @@ -79,21 +81,26 @@ internal data class ClientAuthentication(
internal fun String.urlEncode(): String = URLEncoder.encode(this, StandardCharsets.UTF_8)

internal fun OkHttpClient.post(tokenRequest: TokenRequest): String =
this.newCall(
Request.Builder()
.headers(tokenRequest.headers)
.url(tokenRequest.url)
.post(tokenRequest.body.toRequestBody("application/x-www-form-urlencoded".toMediaType()))
.build(),
).execute().body?.string() ?: throw RuntimeException("could not get response body from url=${tokenRequest.url}")
this
.newCall(
Request
.Builder()
.headers(tokenRequest.headers)
.url(tokenRequest.url)
.post(tokenRequest.body.toRequestBody("application/x-www-form-urlencoded".toMediaType()))
.build(),
).execute()
.body
?.string() ?: throw RuntimeException("could not get response body from url=${tokenRequest.url}")

fun OkHttpClient.withSsl(
ssl: Ssl,
followRedirects: Boolean = false,
): OkHttpClient =
newBuilder().apply {
followRedirects(followRedirects)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply { init(ssl.sslKeystore.keyStore) }
val sslContext = SSLContext.getInstance("TLS").apply { init(null, trustManagerFactory.trustManagers, null) }
sslSocketFactory(sslContext.socketFactory, trustManagerFactory.trustManagers[0] as X509TrustManager)
}.build()
newBuilder()
.apply {
followRedirects(followRedirects)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply { init(ssl.sslKeystore.keyStore) }
val sslContext = SSLContext.getInstance("TLS").apply { init(null, trustManagerFactory.trustManagers, null) }
sslSocketFactory(sslContext.socketFactory, trustManagerFactory.trustManagers[0] as X509TrustManager)
}.build()
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,27 @@ private fun Route.Builder.debuggerForm(sessionManager: SessionManager) =
get(DEBUGGER) {
log.debug("handling GET request, return html form")
val url =
it.url.toAuthorizationEndpointUrl().newBuilder().query(
"client_id=debugger" +
"&response_type=code" +
"&redirect_uri=${it.url.toDebuggerCallbackUrl()}" +
"&response_mode=query" +
"&scope=openid+somescope" +
"&state=1234" +
"&nonce=5678",
).build()
it.url
.toAuthorizationEndpointUrl()
.newBuilder()
.query(
"client_id=debugger" +
"&response_type=code" +
"&redirect_uri=${it.url.toDebuggerCallbackUrl()}" +
"&response_mode=query" +
"&scope=openid+somescope" +
"&state=1234" +
"&nonce=5678",
).build()
html(templateMapper.debuggerFormHtml(url, "CLIENT_SECRET_BASIC"))
}
post(DEBUGGER) {
log.debug("handling POST request, return redirect")
val authorizeUrl = it.formParameters.get("authorize_url") ?: error("authorize_url is missing")
val httpUrl =
authorizeUrl.toHttpUrl().newBuilder()
authorizeUrl
.toHttpUrl()
.newBuilder()
.encodedQuery(it.formParameters.parameterString)
.removeAllEncodedQueryParams("authorize_url", "token_url", "client_secret", "client_auth_method")
.build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ private val log = KotlinLogging.logger { }

class SessionManager {
private val encryptionKey: SecretKey =
KeyGenerator.getInstance("AES")
.apply { this.init(128) }.generateKey()
KeyGenerator
.getInstance("AES")
.apply { this.init(128) }
.generateKey()

fun session(request: OAuth2HttpRequest): Session = Session(encryptionKey, request)

Expand Down Expand Up @@ -52,9 +54,12 @@ class SessionManager {
}.serialize()

private fun String.decrypt(key: SecretKey): String =
JWEObject.parse(this).also {
it.decrypt(DirectDecrypter(key))
}.payload.toString()
JWEObject
.parse(this)
.also {
it.decrypt(DirectDecrypter(key))
}.payload
.toString()

private fun getSessionCookie(): String? =
runCatching {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ private fun HttpUrl.issuer(path: String = ""): HttpUrl =
private fun joinPaths(vararg path: String) = path.filter { it.isNotEmpty() }.joinToString("/") { it.trimPath() }

private fun HttpUrl.baseUrl(): HttpUrl =
HttpUrl.Builder()
HttpUrl
.Builder()
.scheme(this.scheme)
.host(this.host)
.port(this.port)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ fun ClientAuthentication.requirePrivateKeyJwt(
it.clientAssertion.expiresIn() > maxLifetimeSeconds -> {
invalidRequest("invalid client_assertion: client_assertion expiry is too long( should be < $maxLifetimeSeconds)")
}
!it.clientAssertion.jwtClaimsSet.audience.contains(requiredAudience) -> {
!it.clientAssertion.jwtClaimsSet.audience
.contains(requiredAudience) -> {
invalidRequest("invalid client_assertion: client_assertion must contain required audience '$requiredAudience'")
}
else -> it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ import java.net.URLDecoder
import java.nio.charset.StandardCharsets

internal fun String.keyValuesToMap(listDelimiter: String): Map<String, String> =
this.split(listDelimiter)
this
.split(listDelimiter)
.filter { it.contains("=") }
.associate {
val (key, value) = it.split("=")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ fun Map<String, Any>.replaceValues(templates: Map<String, Any>): Map<String, Any
}
}

fun replaceValue(value: Any): Any {
return when (value) {
fun replaceValue(value: Any): Any =
when (value) {
is String -> replaceTemplateString(value, templates)
is List<*> -> value.map { it?.let { replaceValue(it) } }
is Map<*, *> -> value.mapValues { v -> v.value?.let { replaceValue(it) } }
else -> value
}
}

return this.mapValues { replaceValue(it.value) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,19 @@ internal class AuthorizationCodeHandler(
private fun getLoginTokenCallbackOrDefault(
code: AuthorizationCode,
OAuth2TokenCallback: OAuth2TokenCallback,
): OAuth2TokenCallback {
return takeLoginFromCache(code)?.let {
): OAuth2TokenCallback =
takeLoginFromCache(code)?.let {
LoginOAuth2TokenCallback(it, OAuth2TokenCallback)
} ?: OAuth2TokenCallback
}

private fun takeLoginFromCache(code: AuthorizationCode): Login? = codeToLoginCache.remove(code)

private fun takeAuthenticationRequestFromCache(code: AuthorizationCode): AuthenticationRequest? = codeToAuthRequestCache.remove(code)

private class LoginOAuth2TokenCallback(val login: Login, val oAuth2TokenCallback: OAuth2TokenCallback) : OAuth2TokenCallback {
private class LoginOAuth2TokenCallback(
val login: Login,
val oAuth2TokenCallback: OAuth2TokenCallback,
) : OAuth2TokenCallback {
override fun issuerId(): String = oAuth2TokenCallback.issuerId()

override fun subject(tokenRequest: TokenRequest): String = login.username
Expand All @@ -116,7 +118,8 @@ internal class AuthorizationCodeHandler(
oAuth2TokenCallback.addClaims(tokenRequest).toMutableMap().apply {
login.claims?.let {
try {
jsonMapper.readTree(it)
jsonMapper
.readTree(it)
.fields()
.forEach { field ->
put(field.key, jsonMapper.readValue(field.value.toString()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ import no.nav.security.mock.oauth2.token.OAuth2TokenCallback
import no.nav.security.mock.oauth2.token.OAuth2TokenProvider
import okhttp3.HttpUrl

internal class JwtBearerGrantHandler(private val tokenProvider: OAuth2TokenProvider) : GrantHandler {
internal class JwtBearerGrantHandler(
private val tokenProvider: OAuth2TokenProvider,
) : GrantHandler {
override fun tokenResponse(
request: OAuth2HttpRequest,
issuerUrl: HttpUrl,
Expand All @@ -36,11 +38,10 @@ internal class JwtBearerGrantHandler(private val tokenProvider: OAuth2TokenProvi
)
}

private fun TokenRequest.responseScope(): String {
return scope?.toString()
private fun TokenRequest.responseScope(): String =
scope?.toString()
?: assertion().getClaim("scope")?.toString()
?: invalidRequest("scope must be specified in request or as a claim in assertion parameter")
}

private fun TokenRequest.assertion(): JWTClaimsSet =
(this.authorizationGrant as? JWTBearerGrant)?.jwtAssertion?.jwtClaimsSet
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class TokenExchangeGrant(
TokenExchangeGrant(
parameters.require("subject_token_type"),
parameters.require("subject_token"),
parameters.require("audience")
parameters
.require("audience")
.split(" ")
.toMutableList(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import no.nav.security.mock.oauth2.token.OAuth2TokenCallback
import no.nav.security.mock.oauth2.token.OAuth2TokenProvider
import okhttp3.HttpUrl

internal class TokenExchangeGrantHandler(private val tokenProvider: OAuth2TokenProvider) : GrantHandler {
internal class TokenExchangeGrantHandler(
private val tokenProvider: OAuth2TokenProvider,
) : GrantHandler {
override fun tokenResponse(
request: OAuth2HttpRequest,
issuerUrl: HttpUrl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ data class OAuth2HttpRequest(
)

internal fun proxyAwareUrl(): HttpUrl =
HttpUrl.Builder()
HttpUrl
.Builder()
.scheme(resolveScheme())
.host(resolveHost())
.port(resolvePort())
Expand Down Expand Up @@ -127,7 +128,9 @@ data class OAuth2HttpRequest(
return null
}

data class Parameters(val parameterString: String?) {
data class Parameters(
val parameterString: String?,
) {
val map: Map<String, String> = parameterString?.keyValuesToMap("&") ?: emptyMap()

fun get(name: String): String? = map[name]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ import java.util.concurrent.LinkedBlockingQueue

private val log = KotlinLogging.logger {}

class OAuth2HttpRequestHandler(private val config: OAuth2Config) {
class OAuth2HttpRequestHandler(
private val config: OAuth2Config,
) {
private val loginRequestHandler = LoginRequestHandler(templateMapper, config)
private val debuggerRequestHandler = DebuggerRequestHandler(ssl = config.httpServer.sslConfig())
private val tokenCallbackQueue: BlockingQueue<OAuth2TokenCallback> = LinkedBlockingQueue()
Expand Down Expand Up @@ -180,7 +182,10 @@ class OAuth2HttpRequestHandler(private val config: OAuth2Config) {
apply {
if (config.staticAssetsPath != null) {
get("/static/*") {
val path = it.url.pathSegments.drop(1).joinToString("/")
val path =
it.url.pathSegments
.drop(1)
.joinToString("/")
val normalized = Paths.get(path).normalize().toString()
val file = File(config.staticAssetsPath, normalized)

Expand Down
Loading