Skip to content

Commit

Permalink
fix for discriminated union transform
Browse files Browse the repository at this point in the history
  • Loading branch information
lewisjkl committed Jul 16, 2024
1 parent 310edbb commit e817957
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ final class AlloyOpenApiExtension() extends Smithy2OpenApiExtension {
new UntaggedUnions(),
new DataExamplesMapper(),
new ExternalDocumentationMapperJsonSchema(),
new NullableMapper()
new NullableMapper(),
new DiscriminatedUnionShapeId()
).asJava

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,20 @@ class DiscriminatedUnionMemberComponents() extends OpenApiMapper {
.getModel()
.getUnionShapesWithTrait(classOf[DiscriminatedUnionTrait])
val componentBuilder = openapi.getComponents().toBuilder()
val componentSchemas: Map[String, Schema] = openapi
val componentSchemas: Map[ShapeId, Schema] = openapi
.getComponents()
.getSchemas()
.asScala
.toMap
.flatMap { case (_, schema) =>
schema
.getExtension(DiscriminatedUnionShapeId.SHAPE_ID_KEY)
.asScala
.flatMap { node =>
node.toNode.asStringNode.asScala
.map(s => ShapeId.from(s.getValue) -> schema)
}
}
unions.asScala.foreach { union =>
val unionMixinName = union.getId().getName() + "Mixin"
val unionMixinId =
Expand Down Expand Up @@ -82,19 +91,19 @@ class DiscriminatedUnionMemberComponents() extends OpenApiMapper {
componentBuilder.putSchema(syntheticMemberName, syntheticUnionMember)
}

val existingSchemaBuilder = componentSchemas
.get(union.toShapeId.getName)
.map(_.toBuilder())
.getOrElse(Schema.builder())
componentBuilder.putSchema(
union.toShapeId.getName,
updateDiscriminatedUnion(
union,
existingSchemaBuilder,
discriminatorField
)
.build()
)
componentSchemas.get(union.toShapeId).foreach { sch =>
if (!sch.getOneOf.isEmpty) {
componentBuilder.putSchema(
union.toShapeId.getName,
updateDiscriminatedUnion(
union,
sch.toBuilder(),
discriminatorField
)
.build()
)
}
}

}
openapi.toBuilder.components(componentBuilder.build()).build()
Expand Down Expand Up @@ -130,6 +139,7 @@ class DiscriminatedUnionMemberComponents() extends OpenApiMapper {
.asJava
)
schemaBuilder
.removeExtension(DiscriminatedUnionShapeId.SHAPE_ID_KEY)
.oneOf(schemas)
.putExtension(
"discriminator",
Expand Down
44 changes: 44 additions & 0 deletions modules/openapi/src/alloy/openapi/DiscriminatedUnionShapeId.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/* Copyright 2022 Disney Streaming
*
* Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://disneystreaming.github.io/TOST-1.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package alloy.openapi

import _root_.software.amazon.smithy.jsonschema.JsonSchemaConfig
import _root_.software.amazon.smithy.jsonschema.JsonSchemaMapper
import _root_.software.amazon.smithy.jsonschema.Schema.Builder
import _root_.software.amazon.smithy.model.shapes.Shape
import alloy.DiscriminatedUnionTrait

import software.amazon.smithy.model.node.Node

class DiscriminatedUnionShapeId() extends JsonSchemaMapper {

import DiscriminatedUnionShapeId._

override def updateSchema(
shape: Shape,
schemaBuilder: Builder,
config: JsonSchemaConfig
): Builder = if (shape.hasTrait(classOf[DiscriminatedUnionTrait])) {
schemaBuilder.putExtension(
SHAPE_ID_KEY,
Node.from(shape.toShapeId.toString)
)
} else schemaBuilder
}

object DiscriminatedUnionShapeId {
private[openapi] val SHAPE_ID_KEY: String = "SHAPE_ID"
}
67 changes: 67 additions & 0 deletions modules/openapi/test/resources/bar.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
{
"openapi": "3.0.2",
"info": {
"title": "BarService",
"version": ""
},
"paths": {
"/bar": {
"get": {
"operationId": "BarOp",
"responses": {
"200": {
"description": "BarOp200response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/BarOpResponseContent"
}
}
}
}
}
}
}
},
"components": {
"schemas": {
"BarOpResponseContent": {
"type": "object",
"properties": {
"out": {
"$ref": "#/components/schemas/CatOrDog"
}
}
},
"CatOrDog": {
"oneOf": [
{
"type": "object",
"title": "one",
"properties": {
"one": {
"type": "string"
}
},
"required": [
"one"
]
},
{
"type": "object",
"title": "two",
"properties": {
"two": {
"type": "integer",
"format": "int32"
}
},
"required": [
"two"
]
}
]
}
}
}
}
22 changes: 22 additions & 0 deletions modules/openapi/test/resources/bar.smithy
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
$version: "2"

namespace bar

use alloy#simpleRestJson

@simpleRestJson
service BarService {
operations: [BarOp]
}

@http(method: "GET", uri: "/bar")
operation BarOp {
output := {
out: CatOrDog
}
}

union CatOrDog {
one: String
two: Integer
}
24 changes: 24 additions & 0 deletions modules/openapi/test/src/alloy/openapi/OpenApiConversionSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,30 @@ final class OpenApiConversionSpec extends munit.FunSuite {
assertEquals(result, expected)
}

test(
"OpenAPI conversion from alloy#simpleRestJson protocol with multiple namespaces"
) {
val model = Model
.assembler()
.addImport(getClass().getClassLoader().getResource("foo.smithy"))
.addImport(getClass().getClassLoader().getResource("bar.smithy"))
.discoverModels()
.assemble()
.unwrap()

val result = convert(model, Some(Set("bar")))
.map(_.contents)
.mkString
.filterNot(_.isWhitespace)

val expected = Using
.resource(Source.fromResource("bar.json"))(
_.getLines().mkString.filterNot(_.isWhitespace)
)

assertEquals(result, expected)
}

test("OpenAPI conversion from testJson protocol") {
val model = Model
.assembler()
Expand Down

0 comments on commit e817957

Please sign in to comment.