From 422c59467e281849aa174c4cb60c7457fc6c1f3a Mon Sep 17 00:00:00 2001 From: Anton Keks Date: Wed, 23 Aug 2023 12:01:29 +0300 Subject: [PATCH] add support for specifying @SecurityRequirement for routes --- openapi/src/OpenAPI.kt | 11 ++++++++--- openapi/test/OpenAPITest.kt | 16 +++++++++++----- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/openapi/src/OpenAPI.kt b/openapi/src/OpenAPI.kt index ea7708a..0635f33 100644 --- a/openapi/src/OpenAPI.kt +++ b/openapi/src/OpenAPI.kt @@ -10,6 +10,8 @@ import io.swagger.v3.oas.annotations.media.Schema.AccessMode import io.swagger.v3.oas.annotations.parameters.RequestBody import io.swagger.v3.oas.annotations.responses.ApiResponse import io.swagger.v3.oas.annotations.responses.ApiResponses +import io.swagger.v3.oas.annotations.security.SecurityRequirement +import io.swagger.v3.oas.annotations.security.SecurityRequirements import io.swagger.v3.oas.annotations.security.SecurityScheme import io.swagger.v3.oas.annotations.security.SecuritySchemes import io.swagger.v3.oas.annotations.tags.Tag @@ -55,7 +57,7 @@ internal fun Router.generateOpenAPI() = mapOf( "servers" to listOf(mapOf("url" to fullUrl(prefix))), "tags" to toTags(routes), "components" to mapOfNotNull( - "securitySchemes" to (route.annotations.filterIsInstance() + (route.annotation()?.value ?: emptyArray())).associate { s -> + "securitySchemes" to route.repeatableAnnotations().associate { s -> s.name to s.toNonEmptyValues { it.name != "paramName" }.let { it + ("name" to s.paramName) } }.takeIf { it.isNotEmpty() } ).takeIf { it.isNotEmpty() }, @@ -66,6 +68,9 @@ internal fun Router.generateOpenAPI() = mapOf( it.toNonEmptyValues() + ("security" to it.security.associate { it.name to it.scopes.toList() }) } ?: emptyMap()) +private inline fun Route.repeatableAnnotations() = + annotations.filterIsInstance() + (annotation()?.let { it.publicProperties.first().valueOf(it) as Array } ?: emptyArray()) + internal fun toTags(routes: List) = routes.asSequence() .map { it.handler } .filterIsInstance() @@ -83,7 +88,7 @@ internal fun toOperation(route: Route): Pair { }, "requestBody" to toRequestBody(route, route.annotation() ?: op?.requestBody), "responses" to toResponsesByCode(route, op, funHandler?.f?.returnType), - "security" to op?.security?.associate { it.name to it.scopes.toList() } + "security" to route.repeatableAnnotations().associate { it.name to it.scopes.toList() } ) + (op?.let { it.toNonEmptyValues { it.name !in setOf("method", "requestBody", "responses") } } ?: emptyMap()) } @@ -153,7 +158,7 @@ private fun toResponsesByCode(route: Route, op: Operation?, returnType: KType?): val responses = LinkedHashMap() if (returnType?.classifier == Unit::class) responses[NoContent] = mapOf("description" to "No content") else if (op?.responses?.isEmpty() != false) responses[OK] = mapOfNotNull("description" to "OK", "content" to returnType?.toJsonContent(response = true)) - (route.annotations.filterIsInstance() + (route.annotation()?.value ?: emptyArray()) + (op?.responses ?: emptyArray())).forEach { + (route.repeatableAnnotations() + (op?.responses ?: emptyArray())).forEach { responses[StatusCode(it.responseCode.toInt())] = it.toNonEmptyValues { it.name != "responseCode" } } return responses diff --git a/openapi/test/OpenAPITest.kt b/openapi/test/OpenAPITest.kt index 906723f..b20559b 100644 --- a/openapi/test/OpenAPITest.kt +++ b/openapi/test/OpenAPITest.kt @@ -77,7 +77,8 @@ class OpenAPITest { mapOf("name" to "force", "required" to false, "in" to QUERY, "schema" to mapOf("type" to "boolean")) ), "requestBody" to null, - "responses" to mapOf(OK to mapOf("description" to "OK", "content" to mapOf(MimeTypes.json to mapOf("schema" to mapOf("type" to "null"))))) + "responses" to mapOf(OK to mapOf("description" to "OK", "content" to mapOf(MimeTypes.json to mapOf("schema" to mapOf("type" to "null"))))), + "security" to emptyMap() )) } @@ -103,7 +104,8 @@ class OpenAPITest { mapOf("name" to "userId", "required" to true, "in" to PATH, "schema" to mapOf("type" to "string", "format" to "uuid")) ), "requestBody" to mapOf("content" to userSchema(), "required" to true), - "responses" to mapOf(NoContent to mapOf("description" to "No content")) + "responses" to mapOf(NoContent to mapOf("description" to "No content")), + "security" to emptyMap() )) } @@ -112,6 +114,7 @@ class OpenAPITest { @RequestBody(description = "Application and applicant", content = [Content(mediaType = MimeTypes.json, schema = Schema(implementation = User::class))]) @ApiResponse(responseCode = "400", description = "Very bad request") @ApiResponse(responseCode = "401", description = "Unauthorized") + @SecurityRequirement(name = "MySecurity") fun saveUser(e: HttpExchange): User = User("x", UUID.randomUUID()) } expect(toOperation(Route(POST, "/x".toRegex(), handler = FunHandler(MyRoutes(), MyRoutes::saveUser), annotations = MyRoutes::saveUser.annotations))).toEqual("post" to mapOf( @@ -123,7 +126,8 @@ class OpenAPITest { OK to mapOf("description" to "OK", "content" to userSchema(response = true)), BadRequest to mapOf("description" to "Very bad request"), Unauthorized to mapOf("description" to "Unauthorized"), - ) + ), + "security" to mapOf("MySecurity" to emptyList()) )) } @@ -133,7 +137,8 @@ class OpenAPITest { "tags" to emptyList(), "parameters" to null, "requestBody" to null, - "responses" to mapOf(OK to mapOf("description" to "OK")) + "responses" to mapOf(OK to mapOf("description" to "OK")), + "security" to emptyMap() )) } @@ -151,7 +156,8 @@ class OpenAPITest { "parameters" to listOf(mapOf("name" to "param", "in" to QUERY, "description" to "description")), "requestBody" to null, "responses" to mapOf(Found to mapOf("description" to "desc")), - "summary" to "summary" + "summary" to "summary", + "security" to emptyMap() )) }