diff --git a/.scalafmt.conf b/.scalafmt.conf index 31a0d1f..6066985 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,2 +1,4 @@ version = "3.7.3" -runner.dialect = scala213source3 \ No newline at end of file +runner.dialect = scala213source3 +assumeStandardLibraryStripMargin = true +align.stripMargin = true diff --git a/README.md b/README.md index 1009436..f2f6f6a 100644 --- a/README.md +++ b/README.md @@ -44,16 +44,6 @@ _Note: this library is published to work on Java 8 and above. However, you will - [Protobuf](#protobuf) - [CLI Usage](#cli-usage-3) - [Capabilities and Design](#capabilities-and-design-2) - - [Primitives](#primitives-1) - - [Aggregate Types](#aggregate-types) - - [Structure](#structure-1) - - [Union](#union) - - [List](#list-1) - - [Map](#map-1) - - [Constraints](#constraints-1) - - [Enum](#enum-1) - - [Service Shapes](#service-shapes-1) - - [Basic Service](#basic-service-1) - [Options](#options) - [Stringly typed options](#stringly-typed-options) - [Example](#example) @@ -153,26 +143,26 @@ Smithy. #### Primitives -| OpenAPI Base Type | OpenAPI Format | Smithy Shape | Smithy Trait(s) | -|-------------------|--------------------|----------------------------|-------------------------------| -| string | | String | | -| string | timestamp | Timestamp | | -| string | date-time | Timestamp | @timestampFormat("date-time") | -| string | date | String | alloy#dateFormat | -| string | uuid | alloy#UUID | | -| string | binary | Blob | | -| string | byte | Blob | | -| string | password | String | @sensitive | -| number | float | Float | | -| number | double | Double | | -| number | double | Double | | -| number | | Double | | -| integer | int16 | Short | | -| integer | | Integer | | -| integer | int32 | Integer | | -| integer | int64 | Long | | -| boolean | | Boolean | | -| object | (empty properties) | Document | | +| OpenAPI Base Type | OpenAPI Format | Smithy Shape | Smithy Trait(s) | +| ----------------- | ------------------ | ------------ | ----------------------------- | +| string | | String | | +| string | timestamp | Timestamp | | +| string | date-time | Timestamp | @timestampFormat("date-time") | +| string | date | String | alloy#dateFormat | +| string | uuid | alloy#UUID | | +| string | binary | Blob | | +| string | byte | Blob | | +| string | password | String | @sensitive | +| number | float | Float | | +| number | double | Double | | +| number | double | Double | | +| number | | Double | | +| integer | int16 | Short | | +| integer | | Integer | | +| integer | int32 | Integer | | +| integer | int64 | Long | | +| boolean | | Boolean | | +| object | (empty properties) | Document | | #### Aggregate Shapes @@ -1222,282 +1212,7 @@ Run `smithytranslate smithy-to-proto --help` for more usage information. ### Capabilities and Design -#### Primitives - -There are more precises number scalar types in protobuf that don't exist in Smithy. For reference, see [here](https://developers.google.com/protocol-buffers/docs/proto3#scalar). You can still model those using the `@protoNumType` trait. The `@required` trait also has an effect on the final protobuf type because we use Google's wrapper types. See the following table for an exhaustive list: - -| Smithy type | @protoNumType | @required | Proto | -| -------------------- | ------------- | --------- | ---------------------------- | -| bigDecimal | N/A | N/A | message { string value = 1 } | -| bigInteger | N/A | N/A | message { string value = 1 } | -| blob | N/A | false | google.protobuf.BytesValue | -| blob | N/A | true | bytes | -| boolean | N/A | false | google.protobuf.BoolValue | -| boolean | N/A | true | bool | -| double | N/A | false | google.protobuf.DoubleValue | -| double | N/A | true | double | -| float | N/A | false | google.protobuf.FloatValue | -| float | N/A | true | float | -| integer, byte, short | FIXED | false | google.protobuf.Int32Value | -| integer, byte, short | FIXED | true | fixed32 | -| integer, byte, short | FIXED_SIGNED | false | google.protobuf.Int32Value | -| integer, byte, short | FIXED_SIGNED | true | sfixed32 | -| integer, byte, short | N/A | true | google.protobuf.Int32Value | -| integer, byte, short | N/A | true | int32 | -| integer, byte, short | SIGNED | false | google.protobuf.Int32Value | -| integer, byte, short | SIGNED | true | sint32 | -| integer, byte, short | UNSIGNED | false | google.protobuf.UInt32Value | -| integer, byte, short | UNSIGNED | true | uint32 | -| long | FIXED | false | google.protobuf.Int64Value | -| long | FIXED | true | fixed64 | -| long | FIXED_SIGNED | false | google.protobuf.Int64Value | -| long | FIXED_SIGNED | true | sfixed64 | -| long | N/A | true | google.protobuf.Int64Value | -| long | N/A | true | int64 | -| long | SIGNED | false | google.protobuf.Int64Value | -| long | SIGNED | true | sint64 | -| long | UNSIGNED | false | google.protobuf.UInt64Value | -| long | UNSIGNED | true | uint64 | -| string | N/A | false | google.protobuf.StringValue | -| string | N/A | true | string | -| timestamp | N/A | N/A | message { long value = 1 } | - -_Note: we can see from the table that the `@protoNumType` has no effect on non-required integer/long (except `UNSIGNED`). This is because there are no FIXED, FIXED_SIGNED or SIGNED instances in the Google's protobuf wrappers_ - -Smithy Translate has special support for `alloy#UUID`. A custom `message` is used in place of `alloy#UUID`. This message is defined as such and it is optmized for compactness: - -Smithy: -```smithy -structure UUID { - @required - upper_bits: Long - @required - lower_bits: Long -} -``` - -Proto: -```proto -message UUID { - int64 upper_bits = 1; - int64 lower_bits = 2; -} -``` - -#### Aggregate Types - -##### Structure - -Smithy: -```smithy -structure Testing { - myString: String, - myInt: Integer -} -``` - -Proto: -```proto -import "google/protobuf/wrappers.proto"; - -message Testing { - google.protobuf.StringValue myString = 1; - google.protobuf.Int32Value myInt = 2; -} -``` - -##### Union - -Unions in Smithy are tricky to translate to Protobuf because of the nature of `oneOf`. The default encoding will create a top-level `message` that contains a `definition` field which is the `oneOf`. For example: - -Smithy: -```smithy -structure Union { - @required - value: TestUnion -} - -union TestUnion { - num: Integer, - txt: String -} -``` - -Proto: -```proto -message Union { - foo.TestUnion value = 1; -} - -message TestUnion { - oneof definition { - int32 num = 1; - string txt = 2; - } -} -``` - -But you can also use `@protoInlinedOneOf` from `alloy` to render the `oneOf` inside of a specific message. This encoding can be harder to maintain because the `oneOf` field indices are flattened with the outer `message` field indices. On the other hand, this encoding is more compact. - -For example: - -Smithy: -```smithy - -use alloy.proto#protoInlinedOneOf - -structure Union { - @required - value: TestUnion -} - -@protoInlinedOneOf -union TestUnion { - num: Integer, - txt: String -} -``` - -Proto: -```proto -syntax = "proto3"; - -package foo; - -message Union { - oneof value { - int32 num = 1; - string txt = 2; - } -} -``` - -##### List - -Smithy: -```smithy -list StringArrayType { - member: String -} -structure StringArray { - value: StringArrayType -} -``` - -Proto: -```proto -message StringArray { - repeated string value = 1; -} -``` - -##### Map - -Smithy: -```smithy -map StringStringMapType { - key: String, - value: String -} -structure StringStringMap { - value: StringStringMapType -} -``` - -Proto: -```proto -message StringStringMap { - map value = 1; -} -``` - -#### Constraints - -##### Enum - -Smithy: -```smithy -enum Color { - RED - GREEN - BLUE -} -``` - -Proto: -```proto -enum Color { - RED = 0; - GREEN = 1; - BLUE = 2; -} -``` - -#### Service Shapes - -##### Basic Service - -Smithy: -```smithy -use alloy.proto#protoEnabled - -@protoEnabled -service FooService { - operations: [Test] -} - -@http(method: "POST", uri: "/test", code: 200) -operation Test { - input: TestInput, - output: Test200 -} - -structure InputBody { - @required - s: String -} - -structure OutputBody { - sNum: Integer -} - -structure Test200 { - @httpPayload - @required - body: OutputBody -} - -structure TestInput { - @httpPayload - @required - body: InputBody -} -``` - -Proto: -```proto -import "google/protobuf/wrappers.proto"; - -service FooService { - rpc Test(foo.TestInput) returns (foo.Test200); -} - -message InputBody { - string s = 1; -} - -message OutputBody { - google.protobuf.Int32Value sNum = 1; -} - -message Test200 { - foo.OutputBody body = 1; -} - -message TestInput { - foo.InputBody body = 1; -} -``` +The design of the smithy to protobuf translation follows the semantics defined in the [alloy specification](https://github.com/disneystreaming/alloy/blob/main/docs/serialisation/protobuf.md). ### Options @@ -1535,7 +1250,9 @@ metadata "proto_options" = [{ namespace foo -string MyString +structure Foo { + value: String +} ``` Proto: @@ -1547,7 +1264,7 @@ option java_package = "foo.pkg"; package foo; -message MyString { +message Foo { string value = 1; } ``` diff --git a/build.sc b/build.sc index e078787..870393b 100644 --- a/build.sc +++ b/build.sc @@ -301,7 +301,8 @@ trait ProtoModule def ivyDeps = super.ivyDeps() ++ Agg( buildDeps.smithy.build, buildDeps.scalapb.compilerPlugin, - buildDeps.scalapb.protocCache.withDottyCompat(scalaVersion()) + buildDeps.scalapb.protocCache.withDottyCompat(scalaVersion()), + buildDeps.alloy.protobuf ) def scalaPBVersion = buildDeps.scalapb.version diff --git a/buildDeps.sc b/buildDeps.sc index a284c2c..4263fa9 100644 --- a/buildDeps.sc +++ b/buildDeps.sc @@ -3,9 +3,11 @@ import mill.define._ import mill.scalalib._ object alloy { - val alloyVersion = "0.2.8" + val alloyVersion = "0.3.0" val core = ivy"com.disneystreaming.alloy:alloy-core:$alloyVersion" + val protobuf = + ivy"com.disneystreaming.alloy:alloy-protobuf:$alloyVersion" } object circe { val jawn = ivy"io.circe::circe-jawn:0.14.6" diff --git a/buildSetup.sc b/buildSetup.sc index 0cb6f9b..c20d67d 100644 --- a/buildSetup.sc +++ b/buildSetup.sc @@ -28,6 +28,7 @@ object ScalaVersions { } trait BaseModule extends Module with HeaderModule { + def millSourcePath: os.Path = { val originalRelativePath = super.millSourcePath.relativeTo(os.pwd) os.pwd / "modules" / originalRelativePath @@ -111,6 +112,7 @@ trait BaseScala213Module extends BaseScalaModule with ScalafmtModule { } trait BaseScalaModule extends ScalaModule with BaseModule { + override def scalacPluginIvyDeps = T { val sv = scalaVersion() val plugins = diff --git a/modules/cli/src/runners/Proto.scala b/modules/cli/src/runners/Proto.scala index dc4a3cb..7f8da9b 100644 --- a/modules/cli/src/runners/Proto.scala +++ b/modules/cli/src/runners/Proto.scala @@ -17,10 +17,11 @@ package smithytranslate.cli.runners import smithytranslate.cli.opts.ProtoOpts import smithytranslate.cli.transformer.TransformerLookup -import smithyproto.proto3.{Compiler, Renderer, ModelPreProcessor} +import smithytranslate.proto3.* import java.net.URLClassLoader import software.amazon.smithy.model.Model +import software.amazon.smithy.build.TransformContext object Proto { @@ -48,9 +49,9 @@ object Proto { } .assemble .unwrap - val model = ModelPreProcessor( - model0, - transformers ++ ModelPreProcessor.transformers.all(None) + + val model = transformers.foldLeft(model0)((m, transfomer) => + transfomer.transform(TransformContext.builder().model(m).build()) ) run(model, opts.outputPath) @@ -60,31 +61,27 @@ object Proto { */ def runForModel( model0: Model, - outputPath: os.Path, - allowedNamespace: Option[String] + outputPath: os.Path ): Unit = { val transformers = TransformerLookup.getAll() - val model = ModelPreProcessor( - model0, - transformers ++ ModelPreProcessor.transformers.all(allowedNamespace) + val model = transformers.foldLeft(model0)((m, transfomer) => + transfomer.transform(TransformContext.builder().model(m).build()) ) run(model, outputPath) } private def run(model: Model, outputPath: os.Path): Unit = { - val proto3Backend = new Compiler() - val out = proto3Backend.compile(model) + val out = SmithyToProtoCompiler.compile(model) os.walk(outputPath) .filter(p => os.isFile(p) && p.ext == "proto") .foreach(os.remove) - out.foreach { output => - val relpath = os.RelPath(output.path.toIndexedSeq, ups = 0) - val rendering = Renderer.render(output.unit) + out.foreach { case RenderedProtoFile(path, contents) => + val relpath = os.RelPath(path.toIndexedSeq, ups = 0) val outPath = outputPath / relpath os.write( outPath, - data = rendering, + data = contents, createFolders = true ) } diff --git a/modules/proto-examples/protobuf/demo/definitions.proto b/modules/proto-examples/protobuf/demo/definitions.proto index fbba064..dbb95f1 100644 --- a/modules/proto-examples/protobuf/demo/definitions.proto +++ b/modules/proto-examples/protobuf/demo/definitions.proto @@ -6,9 +6,9 @@ import "google/protobuf/empty.proto"; import "demo/common/common.proto"; -import "google/protobuf/wrappers.proto"; +import "google/protobuf/timestamp.proto"; -import "smithytranslate/definitions.proto"; +import "google/protobuf/struct.proto"; service Hello { rpc SayHello(demo.HelloRequest) returns (demo.HelloResponse); @@ -20,13 +20,13 @@ message HelloRequest { string name = 1; demo.common.Language lang = 2; int32 requiredInt = 4; - google.protobuf.Int32Value int = 5; - smithytranslate.UUID id = 6; + int32 int = 5; + string id = 6; } message HelloResponse { string message = 1; - google.protobuf.UInt64Value size = 2; + uint64 size = 2; demo.UseApiStruct apiStruct = 3; oneof apiUnion { string version = 4; @@ -36,7 +36,11 @@ message HelloResponse { } message UseApiStruct { - smithytranslate.BigInteger bigInt = 1; - smithytranslate.BigDecimal bigDec = 2; - smithytranslate.Timestamp ts = 3; + string bigInt = 1; + string bigDec = 2; + google.protobuf.Timestamp ts = 3; +} + +message DocumentWrapper { + google.protobuf.Value doc = 1; } diff --git a/modules/proto-examples/protobuf/smithytranslate/definitions.proto b/modules/proto-examples/protobuf/smithytranslate/definitions.proto deleted file mode 100644 index 79482c6..0000000 --- a/modules/proto-examples/protobuf/smithytranslate/definitions.proto +++ /dev/null @@ -1,20 +0,0 @@ -syntax = "proto3"; - -package smithytranslate; - -message BigDecimal { - string value = 1; -} - -message BigInteger { - string value = 1; -} - -message Timestamp { - int64 value = 1; -} - -message UUID { - int64 upper_bits = 1; - int64 lower_bits = 2; -} diff --git a/modules/proto-examples/smithy/demo.smithy b/modules/proto-examples/smithy/demo.smithy index d6fc963..9f2255c 100644 --- a/modules/proto-examples/smithy/demo.smithy +++ b/modules/proto-examples/smithy/demo.smithy @@ -78,3 +78,9 @@ union ApiUnion { @protoIndex(5) id: Integer } + +@protoEnabled +structure DocumentWrapper { + @required + doc: Document +} diff --git a/modules/proto-examples/src/smithyproto/scalapb/demo/HelloGrpcImpl.scala b/modules/proto-examples/src/smithyproto/scalapb/demo/HelloGrpcImpl.scala index d85715c..23d0d86 100644 --- a/modules/proto-examples/src/smithyproto/scalapb/demo/HelloGrpcImpl.scala +++ b/modules/proto-examples/src/smithyproto/scalapb/demo/HelloGrpcImpl.scala @@ -13,7 +13,7 @@ * limitations under the License. */ -package smithyproto.scalapb.demo +package smithytranslate.scalapb.demo import demo.definitions.HelloGrpc import demo.definitions.HelloRequest diff --git a/modules/proto-examples/src/smithyproto/scalapb/demo/HelloServer.scala b/modules/proto-examples/src/smithyproto/scalapb/demo/HelloServer.scala index bc32551..f4d1913 100644 --- a/modules/proto-examples/src/smithyproto/scalapb/demo/HelloServer.scala +++ b/modules/proto-examples/src/smithyproto/scalapb/demo/HelloServer.scala @@ -13,7 +13,7 @@ * limitations under the License. */ -package smithyproto.scalapb.demo +package smithytranslate.scalapb.demo import io.grpc.ServerBuilder import demo.definitions.HelloGrpc diff --git a/modules/proto/src/smithytranslate/proto3/internals/CompilerOptions.scala b/modules/proto/src/smithytranslate/proto3/RenderedProtoFile.scala similarity index 86% rename from modules/proto/src/smithytranslate/proto3/internals/CompilerOptions.scala rename to modules/proto/src/smithytranslate/proto3/RenderedProtoFile.scala index e1f3500..f2a2171 100644 --- a/modules/proto/src/smithytranslate/proto3/internals/CompilerOptions.scala +++ b/modules/proto/src/smithytranslate/proto3/RenderedProtoFile.scala @@ -13,6 +13,6 @@ * limitations under the License. */ -package smithyproto.proto3 +package smithytranslate.proto3 -final case class CompilerOptions() +case class RenderedProtoFile(path: List[String], contents: String) diff --git a/modules/proto/src/smithytranslate/proto3/SmithyToProtoCompiler.scala b/modules/proto/src/smithytranslate/proto3/SmithyToProtoCompiler.scala new file mode 100644 index 0000000..f6f0d9a --- /dev/null +++ b/modules/proto/src/smithytranslate/proto3/SmithyToProtoCompiler.scala @@ -0,0 +1,47 @@ +/* Copyright 2022 Disney Streaming + * + * Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://disneystreaming.github.io/TOST-1.0.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package smithytranslate.proto3 + +import software.amazon.smithy.model.Model + +object SmithyToProtoCompiler + extends SmithyToProtoCompilerInterface(allShapes = false) + +class SmithyToProtoCompilerInterface private[proto3] ( + allShapes: Boolean +) { + + /** Transforms a smithy model into a list of protobuf files. + */ + def compile( + smithyModel: Model + ): List[RenderedProtoFile] = { + val compiler = + new internals.Compiler(smithyModel, allShapes = allShapes) + compiler + .compile() + .map { compileOutput => + val contents = + smithytranslate.proto3.internals.Renderer.render(compileOutput.unit) + RenderedProtoFile(compileOutput.path, contents) + } + } + + def withConvertAllShapes( + newAllShapes: Boolean + ): SmithyToProtoCompilerInterface = + new SmithyToProtoCompilerInterface(newAllShapes) +} diff --git a/modules/proto/src/smithytranslate/proto3/internals/Compiler.scala b/modules/proto/src/smithytranslate/proto3/internals/Compiler.scala index a8fa683..0ac5aa2 100644 --- a/modules/proto/src/smithytranslate/proto3/internals/Compiler.scala +++ b/modules/proto/src/smithytranslate/proto3/internals/Compiler.scala @@ -13,7 +13,7 @@ * limitations under the License. */ -package smithyproto.proto3 +package smithytranslate.proto3.internals import alloy.proto._ import smithytranslate.closure.ModelOps._ @@ -22,21 +22,61 @@ import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes._ import software.amazon.smithy.model.traits.DeprecatedTrait import software.amazon.smithy.model.traits.EnumTrait -import software.amazon.smithy.model.traits.RequiredTrait -import software.amazon.smithy.model.traits.SparseTrait import software.amazon.smithy.model.traits.UnitTypeTrait import software.amazon.smithy.model.traits.TraitDefinition import scala.jdk.CollectionConverters._ import scala.jdk.OptionConverters._ +import software.amazon.smithy.model.neighbor.NeighborProvider +import software.amazon.smithy.model.neighbor.Walker +import alloy.OpenEnumTrait +import software.amazon.smithy.model.traits.EnumValueTrait -class Compiler() { +private[proto3] class Compiler(model: Model, allShapes: Boolean) { // Reference: // 1. https://github.com/protocolbuffers/protobuf/blob/master/docs/field_presence.md import ProtoIR._ + private lazy val allRelevantShapes: Set[Shape] = { + if (allShapes) { + model + .shapes() + .iterator() + .asScala + .filterNot(ShapeFiltering.exclude) + .toSet + } else { + val walker = new Walker(NeighborProvider.of(model)) + val protoEnabledShapes = + model.getShapesWithTrait(classOf[ProtoEnabledTrait]).asScala + val grpcShapes = model.getShapesWithTrait(classOf[GrpcTrait]).asScala + val allRoots = protoEnabledShapes ++ grpcShapes + val allTransitiveShapes: Set[Shape] = allRoots + .flatMap((shape: Shape) => walker.walkShapes(shape).asScala) + .toSet + (allRoots ++ allTransitiveShapes).toSet + } + } + + private lazy val conflictingEnumValues: Set[MemberShape] = { + val enumMembers = + allRelevantShapes.collect { case m: MemberShape => m }.filter { m => + val container = model.expectShape(m.getContainer()) + container.isIntEnumShape() || container.isEnumShape() + } + def getKey(m: MemberShape) = m.getId().getNamespace() -> m.getMemberName() + val conflicting = enumMembers.groupBy(getKey).filter(_._2.size > 1).keySet + enumMembers.filter(m => conflicting(getKey(m))) + } + + private def enumValueName(m: MemberShape): String = if ( + conflictingEnumValues(m) + ) + m.getId().getName().toUpperCase + "_" + m.getMemberName() + else m.getMemberName() + /** these exclusions are performed as a last step to avoid shapes like * `structure protoEnabled {}` to be rendered as proto messages. * @@ -46,17 +86,9 @@ class Compiler() { */ object ShapeFiltering { - private val passthroughShapeIds: Set[ShapeId] = - Set( - "BigInteger", - "BigDecimal", - "Timestamp", - "UUID" - ).map(name => ShapeId.fromParts("smithytranslate", name)) private def excludeInternal(shape: Shape): Boolean = { val excludeNs = Set("alloy.proto", "alloy", "smithytranslate") - excludeNs.contains(shape.getId().getNamespace()) && - !passthroughShapeIds(shape.getId()) + excludeNs.contains(shape.getId().getNamespace()) } def traitShapes(s: Shape): Boolean = { @@ -65,17 +97,20 @@ class Compiler() { def exclude(s: Shape): Boolean = excludeInternal(s) || Prelude.isPreludeShape(s) || traitShapes(s) + + def include(s: Shape): Boolean = allShapes || allRelevantShapes(s) } - def compile(model: Model): List[OutputFile] = { + def compile(): List[OutputFile] = { val allProtocOptions = MetadataProcessor.extractProtocOptions(model) model.toShapeSet.toList + .filter(ShapeFiltering.include) .filterNot(ShapeFiltering.exclude) .groupBy(_.getId().getNamespace()) .flatMap { case (ns, shapes) => val mappings = shapes.flatMap { shape => shape - .accept(compileVisitor(model)) + .accept(topLevelDefsVisitor) .map(m => Statement.TopLevelStatement(m)) } if (mappings.nonEmpty) { @@ -128,8 +163,8 @@ class Compiler() { private def findFieldIndex(m: MemberShape): Option[Int] = m.getTrait(classOf[ProtoIndexTrait]).toScala.map(_.getNumber) - private def isRequired(m: Shape): Boolean = - m.hasTrait(classOf[RequiredTrait]) + private def hasProtoWrapped(m: Shape): Boolean = + m.hasTrait(classOf[alloy.proto.ProtoWrappedTrait]) private def isProtoService(ss: ServiceShape): Boolean = ss.hasTrait(classOf[ProtoEnabledTrait]) @@ -170,114 +205,124 @@ class Compiler() { } } - type Mappings = List[TopLevelDef] + type TopLevelDefs = List[TopLevelDef] type UnionMappings = Map[ShapeId, TopLevelDef] - private def compileVisitor(model: Model): ShapeVisitor[Mappings] = - new ShapeVisitor.Default[Mappings] { - private def topLevelMessage(shape: Shape, ty: Type) = { - val name = shape.getId.getName - val isDeprecated = shape.hasTrait(classOf[DeprecatedTrait]) - val field = - Field(deprecated = isDeprecated, ty, "value", 1) - val message = - Message(name, List(MessageElement.FieldElement(field)), Nil) - List(TopLevelDef.MessageDef(message)) - } - override def getDefault(shape: Shape): Mappings = Nil - - override def bigIntegerShape(shape: BigIntegerShape): Mappings = { - topLevelMessage(shape, Type.BigInteger) - } - - override def bigDecimalShape(shape: BigDecimalShape): Mappings = { - topLevelMessage(shape, Type.BigDecimal) - } - override def timestampShape(shape: TimestampShape): Mappings = { - topLevelMessage(shape, Type.Timestamp) - } + private object topLevelDefsVisitor + extends ShapeVisitor.Default[TopLevelDefs] { + private def topLevelMessage(shape: Shape, ty: Type) = { + val name = shape.getId.getName + val isDeprecated = shape.hasTrait(classOf[DeprecatedTrait]) + val field = + Field(deprecated = isDeprecated, ty, "value", 1) + val message = + Message(name, List(MessageElement.FieldElement(field)), Nil) + List(TopLevelDef.MessageDef(message)) + } - // TODO: streaming requests and response types - override def serviceShape(shape: ServiceShape): Mappings = - // TODO: is this the best place to do the filtering? or should it be done in a preprocessing phase - if (isProtoService(shape)) { - val operations = shape.getOperations.asScala.toList - .map(model.expectShape(_)) + private def isSimpleShape(shape: Shape): Boolean = + shape.getType().getCategory() == ShapeType.Category.SIMPLE + + override def getDefault(shape: Shape): TopLevelDefs = + if (isSimpleShape(shape) && hasProtoWrapped(shape)) { + val maybeNumType = shape + .getTrait(classOf[ProtoNumTypeTrait]) + .toScala + .map(_.getNumType()) + val maybeType = shape.accept(typeVisitor(false, maybeNumType)) + maybeType.toList.flatMap(topLevelMessage(shape, _)) + } else Nil + + // TODO: streaming requests and response types + override def serviceShape(shape: ServiceShape): TopLevelDefs = + // TODO: is this the best place to do the filtering? or should it be done in a preprocessing phase + if (isProtoService(shape)) { + val operations = shape.getOperations.asScala.toList + .map(model.expectShape(_)) + + val defs = operations.flatMap(_.accept(this)) + val rpcs = operations.flatMap(_.accept(rpcVisitor)) + val service = Service(shape.getId.getName, rpcs) + + List(TopLevelDef.ServiceDef(service)) ++ defs + } else Nil + + @annotation.nowarn( + "msg=class EnumTrait in package (.*)traits is deprecated" + ) + override def stringShape(shape: StringShape): TopLevelDefs = { + val name = shape.getId.getName + getEnumTrait(shape).map { (et: EnumTrait) => + val reserved = getReservedValues(shape) + val elements = et + .getValues() + .asScala + .toList + .zipWithIndex + .map { case (ed, edFieldNumber) => + val eName = ed + .getName() + .toScala + .getOrElse( + sys.error( + s"Error on shape: ${shape.getId()}: `enum` should have `name` defined." + ) + ) - val defs = operations.flatMap(_.accept(this)) - val rpcs = operations.flatMap(_.accept(rpcVisitor)) - val service = Service(shape.getId.getName, rpcs) + EnumValue(eName, edFieldNumber) + } - List(TopLevelDef.ServiceDef(service)) ++ defs + List(TopLevelDef.EnumDef(Enum(name, elements, reserved))) + } getOrElse { + if (shape.hasTrait(classOf[ProtoWrappedTrait])) { + topLevelMessage(shape, Type.String) } else Nil - - override def booleanShape(shape: BooleanShape): Mappings = { - topLevelMessage(shape, Type.Bool) - } - - override def blobShape(shape: BlobShape): Mappings = { - topLevelMessage(shape, Type.Bytes) - } - - override def integerShape(shape: IntegerShape): Mappings = { - topLevelMessage(shape, Type.Int32) - } - - override def longShape(shape: LongShape): Mappings = { - topLevelMessage(shape, Type.Int64) - } - - override def doubleShape(shape: DoubleShape): Mappings = { - topLevelMessage(shape, Type.Double) - } - - override def shortShape(shape: ShortShape): Mappings = { - topLevelMessage(shape, Type.Int32) - } - - override def floatShape(shape: FloatShape): Mappings = { - topLevelMessage(shape, Type.Float) - } - - override def documentShape(shape: DocumentShape): Mappings = { - topLevelMessage(shape, Type.Any) } + } - override def stringShape(shape: StringShape): Mappings = { - val name = shape.getId.getName - getEnumTrait(shape).map { et => - val reserved = getReservedValues(shape) - val elements = et - .getValues() - .asScala - .toList - .zipWithIndex - .map { case (ed, edFieldNumber) => - val eName = ed - .getName() - .toScala - .getOrElse( - sys.error( - s"Error on shape: ${shape.getId()}: `enum` should have `name` defined." - ) - ) + private def shouldWrapCollection(shape: Shape): Boolean = { + val hasWrapped = hasProtoWrapped(shape) + val membersTargetingThis = + model.getMemberShapes().asScala.filter(_.getTarget() == shape.getId()) + val isTargetedByWrappedMember = + membersTargetingThis.exists(hasProtoWrapped(_)) + // oneofs cannot have lists / maps fields + val isTargetedByUnionMember = + membersTargetingThis.exists(member => + model.expectShape(member.getContainer()).isUnionShape + ) - EnumValue(eName, edFieldNumber) - } + hasWrapped || isTargetedByWrappedMember || isTargetedByUnionMember + } - List(TopLevelDef.EnumDef(Enum(name, elements, reserved))) - } getOrElse { - topLevelMessage(shape, Type.String) + override def listShape(shape: ListShape): TopLevelDefs = { + if (shouldWrapCollection(shape)) { + shape.getMember().accept(typeVisitor()).toList.flatMap { tpe => + topLevelMessage(shape, Type.ListType(tpe)) } - } + } else Nil + } - override def enumShape(shape: EnumShape): Mappings = { + override def mapShape(shape: MapShape): TopLevelDefs = { + if (shouldWrapCollection(shape)) { + for { + keyType <- shape.getKey().accept(typeVisitor()).toList + valueType <- shape.getValue().accept(typeVisitor()).toList + result <- topLevelMessage(shape, Type.MapType(keyType, valueType)) + } yield result + } else Nil + } + + override def enumShape(shape: EnumShape): TopLevelDefs = { + if (shape.hasTrait(classOf[OpenEnumTrait])) { + Nil + } else { val reserved: List[Reserved] = getReservedValues(shape) val elements: List[EnumValue] = - shape.getAllMembers.asScala.toList.zipWithIndex - .map { case ((name, member), edFieldNumber) => + shape.members.asScala.toList.zipWithIndex + .map { case (member, edFieldNumber) => val fieldIndex = findFieldIndex(member).getOrElse(edFieldNumber) - EnumValue(name, fieldIndex) + EnumValue(enumValueName(member), fieldIndex) } List( TopLevelDef.EnumDef( @@ -285,12 +330,22 @@ class Compiler() { ) ) } + } - override def intEnumShape(shape: IntEnumShape): Mappings = { + override def intEnumShape(shape: IntEnumShape): TopLevelDefs = { + if (shape.hasTrait(classOf[OpenEnumTrait])) { + Nil + } else { val reserved: List[Reserved] = getReservedValues(shape) - val elements = shape.getEnumValues.asScala.toList.map { - case (name, value) => - EnumValue(name, value) + val elements = shape.members.asScala.toList.map { member => + val enumValue = + member.expectTrait(classOf[EnumValueTrait]).expectIntValue() + val protoIndex = member + .getTrait(classOf[ProtoIndexTrait]) + .toScala + .map(_.getNumber()) + .getOrElse(enumValue) + EnumValue(enumValueName(member), protoIndex) } List( TopLevelDef.EnumDef( @@ -298,99 +353,111 @@ class Compiler() { ) ) } + } - private def unionShouldBeInlined(shape: UnionShape): Boolean = { - shape.hasTrait(classOf[alloy.proto.ProtoInlinedOneOfTrait]) - } + private def unionShouldBeInlined(shape: UnionShape): Boolean = { + shape.hasTrait(classOf[alloy.proto.ProtoInlinedOneOfTrait]) + } - override def unionShape(shape: UnionShape): Mappings = { - if (!unionShouldBeInlined(shape)) { - val element = - MessageElement.OneofElement(processUnion("definition", shape, 1)) - val name = shape.getId.getName - val reserved = getReservedValues(shape) - val message = Message(name, List(element), reserved) - List(TopLevelDef.MessageDef(message)) - } else { - List.empty - } + override def unionShape(shape: UnionShape): TopLevelDefs = { + if (!unionShouldBeInlined(shape)) { + val element = + MessageElement.OneofElement(processUnion("definition", shape, 1)) + val name = shape.getId.getName + val reserved = getReservedValues(shape) + val message = Message(name, List(element), reserved) + List(TopLevelDef.MessageDef(message)) + } else { + List.empty } + } - override def structureShape(shape: StructureShape): Mappings = { - val name = shape.getId.getName - val messageElements = - shape.members.asScala.toList - // using foldLeft to accumulate the field count when we fork to - // process a union - .foldLeft((List.empty[MessageElement], 0)) { - case ((fields, fieldCount), m) => - val fieldName = m.getMemberName - val fieldIndex = findFieldIndex(m).getOrElse(fieldCount + 1) - // We assume the model is well-formed so the result should be non-null - val targetShape = model.getShape(m.getTarget).get - targetShape - .asUnionShape() - .toScala - .filter(unionShape => unionShouldBeInlined(unionShape)) - .map { union => - val field = MessageElement.OneofElement( - processUnion(fieldName, union, fieldIndex) - ) - (fields :+ field, fieldCount + field.oneof.fields.size) - } - .getOrElse { - val isDeprecated = m.hasTrait(classOf[DeprecatedTrait]) - val isBoxed = isRequired(m) || isRequired(targetShape) - val numType = extractNumType(m) - val fieldType = + override def structureShape(shape: StructureShape): TopLevelDefs = { + val name = shape.getId.getName + val messageElements = + shape.members.asScala.toList + // using foldLeft to accumulate the field count when we fork to + // process a union + .foldLeft((List.empty[MessageElement], 0)) { + case ((fields, fieldCount), m) => + val fieldName = m.getMemberName + val fieldIndex = findFieldIndex(m).getOrElse(fieldCount + 1) + val targetShape = model.expectShape(m.getTarget) + targetShape + .asUnionShape() + .toScala + .filter(unionShape => unionShouldBeInlined(unionShape)) + .map { union => + val field = MessageElement.OneofElement( + processUnion(fieldName, union, fieldIndex) + ) + (fields :+ field, fieldCount + field.oneof.fields.size) + } + .getOrElse { + val isDeprecated = m.hasTrait(classOf[DeprecatedTrait]) + val fieldType = + if (hasProtoWrapped(targetShape)) { + Type.RefType(targetShape) + } else { + val numType = extractNumType(m) + val wrapped = hasProtoWrapped(m) targetShape - .accept(typeVisitor(model, isBoxed, numType)) + .accept(typeVisitor(wrapped, numType)) .get - val field = MessageElement.FieldElement( - Field( - deprecated = isDeprecated, - fieldType, - fieldName, - fieldIndex - ) + } + val field = MessageElement.FieldElement( + Field( + deprecated = isDeprecated, + fieldType, + fieldName, + fieldIndex ) - (fields :+ field, fieldCount + 1) - } - } - ._1 + ) + (fields :+ field, fieldCount + 1) + } + } + ._1 - val reserved = getReservedValues(shape) - val message = Message(name, messageElements, reserved) - List(TopLevelDef.MessageDef(message)) - } + val reserved = getReservedValues(shape) + val message = Message(name, messageElements, reserved) + List(TopLevelDef.MessageDef(message)) + } - private def processUnion( - name: String, - shape: UnionShape, - indexStart: Int - ): Oneof = { - val fields = shape.members.asScala.toList.zipWithIndex.map { - case (m, fn) => - val fieldName = m.getMemberName - val fieldIndex = findFieldIndex(m).getOrElse(indexStart + fn) - // We assume the model is well-formed so the result should be non-null - val targetShape = model.getShape(m.getTarget).get - val numType = extractNumType(m) - val fieldType = - targetShape - .accept(typeVisitor(model, isRequired = true, numType)) - .get - val isDeprecated = m.hasTrait(classOf[DeprecatedTrait]) - Field( - deprecated = isDeprecated, - fieldType, - fieldName, - fieldIndex - ) - } - Oneof(name, fields) + private def processUnion( + name: String, + shape: UnionShape, + indexStart: Int + ): Oneof = { + val fields = shape.members.asScala.toList.zipWithIndex.map { + case (m, fn) => + val fieldName = m.getMemberName + val fieldIndex = findFieldIndex(m).getOrElse(indexStart + fn) + // We assume the model is well-formed so the result should be non-null + val targetShape = model.expectShape(m.getTarget) + val numType = extractNumType(m) + val isWrapped = { + val memberHasWrapped = hasProtoWrapped(m) + val targetHasWrapped = hasProtoWrapped(targetShape) + // repeated / map fields cannot be in oneofs + val isList = targetShape.isListShape() + val isMap = targetShape.isMapShape() + memberHasWrapped || targetHasWrapped || isList || isMap + } + val fieldType = + targetShape + .accept(typeVisitor(isWrapped = isWrapped, numType)) + .get + val isDeprecated = m.hasTrait(classOf[DeprecatedTrait]) + Field( + deprecated = isDeprecated, + fieldType, + fieldName, + fieldIndex + ) } + Oneof(name, fields) } + } private def getReservedValues(shape: Shape): List[Reserved] = shape @@ -407,30 +474,29 @@ class Compiler() { // TODO: collisions in synthesized name - private def rpcVisitor: ShapeVisitor[Option[Rpc]] = - new ShapeVisitor.Default[Option[Rpc]] { - override def getDefault(shape: Shape): Option[Rpc] = None - override def operationShape(shape: OperationShape): Option[Rpc] = { - val maybeInputShapeId = shape.getInput() - val outputShapeId = shape.getOutput().get() - val request = maybeInputShapeId.toScala - .map { inputShapeId => - RpcMessage( - Namespacing.shapeIdToFqn(inputShapeId), - Namespacing.namespaceToFqn(inputShapeId.getNamespace()) - ) - } - .getOrElse { - RpcMessage(Type.Empty.fqn, Type.Empty.fqn) - } + private object rpcVisitor extends ShapeVisitor.Default[Option[Rpc]] { + override def getDefault(shape: Shape): Option[Rpc] = None + override def operationShape(shape: OperationShape): Option[Rpc] = { + val maybeInputShapeId = shape.getInput() + val outputShapeId = shape.getOutput().get() + val request = maybeInputShapeId.toScala + .map { inputShapeId => + RpcMessage( + Namespacing.shapeIdToFqn(inputShapeId), + Namespacing.namespaceToFqn(inputShapeId.getNamespace()) + ) + } + .getOrElse { + RpcMessage(Type.Empty.fqn, Type.Empty.fqn) + } - val response = RpcMessage( - Namespacing.shapeIdToFqn(outputShapeId), - Namespacing.namespaceToFqn(outputShapeId.getNamespace()) - ) - Some(Rpc(shape.getId.getName, false, request, false, response)) - } + val response = RpcMessage( + Namespacing.shapeIdToFqn(outputShapeId), + Namespacing.namespaceToFqn(outputShapeId.getNamespace()) + ) + Some(Rpc(shape.getId.getName, false, request, false, response)) } + } private def extractNumType( shape: Shape @@ -441,12 +507,6 @@ class Compiler() { .map { _.getNumType() } } - private def isSparse(shape: Shape): Boolean = { - shape - .getTrait(classOf[SparseTrait]) - .isPresent() - } - private def isUnit(shape: StructureShape): Boolean = { shape .getTrait(classOf[UnitTypeTrait]) @@ -458,210 +518,124 @@ class Compiler() { // https://awslabs.github.io/smithy/1.0/spec/core/model.html#simple-shapes // TODO: namespace in type? private def typeVisitor( - model: Model, - isRequired: Boolean, - numType: Option[ProtoNumTypeTrait.NumType] + isWrapped: Boolean = false, + numType: Option[ProtoNumTypeTrait.NumType] = None ): ShapeVisitor[Option[Type]] = new ShapeVisitor[Option[Type]] { - def bigDecimalShape(shape: BigDecimalShape): Option[Type] = Some({ - if (Prelude.isPreludeShape(shape.getId())) { - Type.BigDecimal - } else { - Type.MessageType( - Namespacing.shapeIdToFqn(shape.getId), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) - } - }) - def bigIntegerShape(shape: BigIntegerShape): Option[Type] = Some({ - if (Prelude.isPreludeShape(shape.getId())) { - Type.BigInteger - } else { - Type.MessageType( - Namespacing.shapeIdToFqn(shape.getId), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) - } - }) - def blobShape(shape: BlobShape): Option[Type] = Some( - if (isRequired && Prelude.isPreludeShape(shape.getId())) { - Type.Bytes - } else if (Prelude.isPreludeShape(shape.getId())) { - Type.Wrappers.Bytes - } else { - Type.MessageType( - Namespacing.shapeIdToFqn(shape.getId), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) - } - ) - def booleanShape(shape: BooleanShape): Option[Type] = Some( - if (isRequired && Prelude.isPreludeShape(shape.getId())) { - Type.Bool - } else if (Prelude.isPreludeShape(shape.getId())) { - Type.Wrappers.Bool - } else { - Type.MessageType( - Namespacing.shapeIdToFqn(shape.getId), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) - } - ) - def byteShape(shape: ByteShape): Option[Type] = - if (Prelude.isPreludeShape(shape.getId())) { - Some(NumberType.resolveInt(isRequired, numType)) - } else { - Some( - Type.MessageType( - Namespacing.shapeIdToFqn(shape.getId), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) - ) - } - def documentShape(shape: DocumentShape): Option[Type] = - Some(Type.Any) - def doubleShape(shape: DoubleShape): Option[Type] = - if (isRequired && Prelude.isPreludeShape(shape.getId())) - Some(Type.Double) - else if (Prelude.isPreludeShape(shape.getId())) - Some(Type.Wrappers.Double) - else - Some( - Type.MessageType( - Namespacing.shapeIdToFqn(shape.getId), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) - ) - def floatShape(shape: FloatShape): Option[Type] = - if (isRequired && Prelude.isPreludeShape(shape.getId())) - Some(Type.Float) - else if (Prelude.isPreludeShape(shape.getId())) - Some(Type.Wrappers.Float) - else - Some( - Type.MessageType( - Namespacing.shapeIdToFqn(shape.getId), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) - ) - def integerShape(shape: IntegerShape): Option[Type] = { - if (Prelude.isPreludeShape(shape.getId())) { - Some(NumberType.resolveInt(isRequired, numType)) - } else { - Some( - Type.MessageType( - Namespacing.shapeIdToFqn(shape.getId), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) - ) - } + def bigDecimalShape(shape: BigDecimalShape): Option[Type] = Some { + if (!isWrapped) Type.String + else Type.AlloyWrappers.BigDecimal } - def listShape(shape: ListShape): Option[Type] = { - val memberShape = model.getShape(shape.getMember().getTarget()).get - // to do sparse & numtype - memberShape - .accept( - typeVisitor(model, isRequired = !isSparse(shape), numType = None) - ) - .map(Type.ListType(_)) + def bigIntegerShape(shape: BigIntegerShape): Option[Type] = Some { + if (!isWrapped) Type.String + else Type.AlloyWrappers.BigInteger } - def longShape(shape: LongShape): Option[Type] = { - if (Prelude.isPreludeShape(shape.getId())) { - Some(NumberType.resolveLong(isRequired, numType)) - } else { - Some( - Type.MessageType( - Namespacing.shapeIdToFqn(shape.getId), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) - ) - } + def blobShape(shape: BlobShape): Option[Type] = Some { + if (!isWrapped) Type.Bytes + else Type.GoogleWrappers.Bytes + } + def booleanShape(shape: BooleanShape): Option[Type] = Some { + if (!isWrapped) Type.Bool + else Type.GoogleWrappers.Bool + } + def byteShape(shape: ByteShape): Option[Type] = Some { + if (!isWrapped) Type.Int32 + else Type.AlloyWrappers.ByteValue } + def documentShape(shape: DocumentShape): Option[Type] = Some { + if (!isWrapped) Type.GoogleValue + else Type.AlloyWrappers.Document + } + + def doubleShape(shape: DoubleShape): Option[Type] = Some { + if (!isWrapped) Type.Double + else Type.GoogleWrappers.Double + } + def floatShape(shape: FloatShape): Option[Type] = Some { + if (!isWrapped) Type.Float + else Type.GoogleWrappers.Float + } + def shortShape(shape: ShortShape): Option[Type] = Some { + if (!isWrapped) Type.Int32 + else Type.AlloyWrappers.ShortValue + } + def integerShape(shape: IntegerShape): Option[Type] = Some { + NumberType.resolveInt(isWrapped, numType) + } + def longShape(shape: LongShape): Option[Type] = Some { + NumberType.resolveLong(isWrapped, numType) + } + + def listShape(shape: ListShape): Option[Type] = { + if (isWrapped) Some(Type.RefType(shape)) + else shape.getMember().accept(typeVisitor()).map(Type.ListType(_)) + } + def mapShape(shape: MapShape): Option[Type] = { - for { - valueShape <- model.getShape(shape.getValue.getTarget).toScala - valueType <- valueShape.accept( - typeVisitor(model, isRequired = !isSparse(shape), numType = None) - ) - } yield Type.MapType(Right(Type.String), valueType) + if (isWrapped) Some(Type.RefType(shape)) + else + for { + key <- shape.getKey().accept(typeVisitor()) + value <- shape.getValue().accept(typeVisitor()) + } yield Type.MapType(key, value) + } + + def memberShape(shape: MemberShape): Option[Type] = { + val target = model.expectShape(shape.getTarget()) + val memberHasWrapped = shape.hasTrait(classOf[ProtoWrappedTrait]) + val targetHasWrapped = target.hasTrait(classOf[ProtoWrappedTrait]) + val isWrapped = memberHasWrapped || targetHasWrapped + val numType = + shape + .getTrait(classOf[ProtoNumTypeTrait]) + .or(() => target.getTrait(classOf[ProtoNumTypeTrait])) + .toScala + .map(_.getNumType()) + + target.accept(typeVisitor(isWrapped, numType)) } - def memberShape(shape: MemberShape): Option[Type] = None + def operationShape(shape: OperationShape): Option[Type] = None def resourceShape(shape: ResourceShape): Option[Type] = None def serviceShape(shape: ServiceShape): Option[Type] = None - @annotation.nowarn( - "msg=class SetShape in package (.*)shapes is deprecated" - ) - override def setShape(shape: SetShape): Option[Type] = Some( - Type.MessageType( - Namespacing.shapeIdToFqn(shape.getId), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) - ) - def shortShape(shape: ShortShape): Option[Type] = - if (Prelude.isPreludeShape(shape.getId())) { - Some(NumberType.resolveInt(isRequired, numType)) - } else { - Some( - Type.MessageType( - Namespacing.shapeIdToFqn(shape.getId), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) - ) - } - // TODO: we are diverging from the spec here - def stringShape(shape: StringShape): Option[Type] = Some( - if (isRequired && Prelude.isPreludeShape(shape.getId())) { - Type.String - } else if (Prelude.isPreludeShape(shape.getId())) { - Type.Wrappers.String - } else { - Type.MessageType( - Namespacing.shapeIdToFqn(shape.getId), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) - } - ) - override def enumShape(shape: EnumShape): Option[Type] = Some( - Type.EnumType( - Namespacing.shapeIdToFqn(shape.getId()), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) - ) - override def intEnumShape(shape: IntEnumShape): Option[Type] = Some( - Type.EnumType( - Namespacing.shapeIdToFqn(shape.getId()), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) - ) - def structureShape(shape: StructureShape): Option[Type] = { - if (isUnit(shape)) { - Some(Type.Empty) + def stringShape(shape: StringShape): Option[Type] = Some { + val hasUUIDFormat = shape.hasTrait(classOf[alloy.UuidFormatTrait]) + val hasProtoCompactUUID = + shape.hasTrait(classOf[alloy.proto.ProtoCompactUUIDTrait]) + if (hasUUIDFormat && hasProtoCompactUUID) Type.AlloyTypes.CompactUUID + else if (!isWrapped) Type.String + else Type.GoogleWrappers.String + } + override def enumShape(shape: EnumShape): Option[Type] = { + if (shape.hasTrait(classOf[OpenEnumTrait])) { + Some(Type.String) } else { - Some( - Type.MessageType( - Namespacing.shapeIdToFqn(shape.getId), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) - ) + Some(Type.RefType(shape)) } } - def timestampShape(shape: TimestampShape): Option[Type] = Some( - if (Prelude.isPreludeShape(shape.getId())) { - Type.Timestamp + override def intEnumShape(shape: IntEnumShape): Option[Type] = { + if (shape.hasTrait(classOf[OpenEnumTrait])) { + Some(Type.Int32) } else { - Type.MessageType( - Namespacing.shapeIdToFqn(shape.getId), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) + Some(Type.RefType(shape)) } - ) + } + + def structureShape(shape: StructureShape): Option[Type] = Some { + if (isUnit(shape)) + Type.Empty + else + Type.RefType(shape) + } + + def timestampShape(shape: TimestampShape): Option[Type] = Some { + if (!isWrapped) Type.GoogleTimestamp + else Type.AlloyWrappers.Timestamp + } + def unionShape(shape: UnionShape): Option[Type] = Some( - Type.MessageType( - Namespacing.shapeIdToFqn(shape.getId), - Namespacing.shapeIdToImportFqn(shape.getId()) - ) + Type.RefType(shape) ) } @@ -671,37 +645,43 @@ class Compiler() { private object NumberType { def resolveLong( - isRequired: Boolean, + isWrapped: Boolean, maybeNumType: Option[ProtoNumTypeTrait.NumType] ): Type = { import ProtoNumTypeTrait.NumType._ - (isRequired, maybeNumType) match { - case (true, Some(SIGNED)) => Type.Sint64 - case (true, Some(UNSIGNED)) => Type.Uint64 - case (true, Some(FIXED)) => Type.Fixed64 - case (true, Some(FIXED_SIGNED)) => Type.Sfixed64 - case (true, Some(UNKNOWN)) => Type.Int64 - case (true, None) => Type.Int64 - case (false, Some(UNSIGNED)) => Type.Wrappers.Uint64 - case (false, Some(_)) => Type.Wrappers.Int64 - case (false, None) => Type.Wrappers.Int64 + (isWrapped, maybeNumType) match { + case (false, Some(SIGNED)) => Type.Sint64 + case (false, Some(UNSIGNED)) => Type.Uint64 + case (false, Some(FIXED)) => Type.Fixed64 + case (false, Some(FIXED_SIGNED)) => Type.Sfixed64 + case (false, Some(UNKNOWN)) => Type.Int64 + case (false, None) => Type.Int64 + case (true, Some(SIGNED)) => Type.AlloyWrappers.SInt64 + case (true, Some(UNSIGNED)) => Type.GoogleWrappers.Uint64 + case (true, Some(FIXED)) => Type.AlloyWrappers.Fixed64 + case (true, Some(FIXED_SIGNED)) => Type.AlloyWrappers.SFixed64 + case (true, Some(UNKNOWN)) => Type.GoogleWrappers.Int64 + case (true, None) => Type.GoogleWrappers.Int64 } } def resolveInt( - isRequired: Boolean, + isWrapped: Boolean, maybeNumType: Option[ProtoNumTypeTrait.NumType] ): Type = { import ProtoNumTypeTrait.NumType._ - (isRequired, maybeNumType) match { - case (true, Some(SIGNED)) => Type.Sint32 - case (true, Some(UNSIGNED)) => Type.Uint32 - case (true, Some(FIXED)) => Type.Fixed32 - case (true, Some(FIXED_SIGNED)) => Type.Sfixed32 - case (true, Some(UNKNOWN)) => Type.Int32 - case (true, None) => Type.Int32 - case (false, Some(UNSIGNED)) => Type.Wrappers.Uint32 - case (false, Some(_)) => Type.Wrappers.Int32 - case (false, None) => Type.Wrappers.Int32 + (isWrapped, maybeNumType) match { + case (false, Some(SIGNED)) => Type.Sint32 + case (false, Some(UNSIGNED)) => Type.Uint32 + case (false, Some(FIXED)) => Type.Fixed32 + case (false, Some(FIXED_SIGNED)) => Type.Sfixed32 + case (false, Some(UNKNOWN)) => Type.Int32 + case (false, None) => Type.Int32 + case (true, Some(SIGNED)) => Type.AlloyWrappers.SInt32 + case (true, Some(UNSIGNED)) => Type.GoogleWrappers.Uint32 + case (true, Some(FIXED)) => Type.AlloyWrappers.Fixed32 + case (true, Some(FIXED_SIGNED)) => Type.AlloyWrappers.SFixed32 + case (true, Some(UNKNOWN)) => Type.GoogleWrappers.Int32 + case (true, None) => Type.GoogleWrappers.Int32 } } } diff --git a/modules/proto/src/smithytranslate/proto3/internals/MetadataProcessor.scala b/modules/proto/src/smithytranslate/proto3/internals/MetadataProcessor.scala index 70de803..d562ef2 100644 --- a/modules/proto/src/smithytranslate/proto3/internals/MetadataProcessor.scala +++ b/modules/proto/src/smithytranslate/proto3/internals/MetadataProcessor.scala @@ -13,13 +13,13 @@ * limitations under the License. */ -package smithyproto.proto3 +package smithytranslate.proto3.internals import software.amazon.smithy.model.Model import scala.jdk.OptionConverters._ import scala.jdk.CollectionConverters._ -object MetadataProcessor { +private[internals] object MetadataProcessor { type ProtocOptions = Map[String, Map[String, String]] diff --git a/modules/proto/src/smithytranslate/proto3/internals/ModelPreProcessor.scala b/modules/proto/src/smithytranslate/proto3/internals/ModelPreProcessor.scala deleted file mode 100644 index c73e86d..0000000 --- a/modules/proto/src/smithytranslate/proto3/internals/ModelPreProcessor.scala +++ /dev/null @@ -1,308 +0,0 @@ -/* Copyright 2022 Disney Streaming - * - * Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://disneystreaming.github.io/TOST-1.0.txt - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package smithyproto.proto3 - -import java.util.stream.Collectors -import smithytranslate.closure.TransitiveModel -import smithytranslate.UUID -import software.amazon.smithy.build.{ProjectionTransformer, TransformContext} -import software.amazon.smithy.model.Model -import software.amazon.smithy.model.loader.Prelude -import software.amazon.smithy.model.shapes._ - -import java.util -import scala.jdk.CollectionConverters._ -import scala.collection.compat._ - -object ModelPreProcessor { - - object transformers { - object Transitive { - def apply(allowedNamespace: Option[String]) = - new ProjectionTransformer() { - def getName(): String = "transitive-filtering" - def transform(x: TransformContext): Model = { - val annotatedShapes = x - .getModel() - .getShapesWithTrait(classOf[alloy.proto.ProtoEnabledTrait]) - .asScala - .map(_.getId()) - .filter(id => allowedNamespace.forall(_ == id.getNamespace())) - .toList - if (annotatedShapes.size < 1) { - System.err.println( - s"No shapes annotated with ${alloy.proto.ProtoEnabledTrait.ID} were found." - ) - } - TransitiveModel.compute( - x.getModel(), - annotatedShapes, - captureTraits = true, - captureMetadata = true, - validateModel = - false // model may be in invalid state since it is in a transient/intermediary state of the proto conversion - ) - } - } - } - - /** This ProjectionTransformer is used to introduce shapes that are part of - * the `smithytranslate` and used to replace Prelude shapes. These shapes - * are: - * - BigInteger - * - BigDecimal - * - Timestamp - * @param original - * @return - */ - val PreludeReplacements = new ProjectionTransformer() { - // Prelude.getPreludeModel is not accessible - private val preludeModel = Model.assembler().assemble().unwrap() - private val addIfUsed = Map( - // format: off - (classOf[BigIntegerShape], (smithytranslate.BigInteger.shape, smithytranslate.BigInteger.target)), - (classOf[BigDecimalShape], (smithytranslate.BigDecimal.shape, smithytranslate.BigDecimal.target)), - (classOf[TimestampShape], (smithytranslate.Timestamp.shape, smithytranslate.Timestamp.target)) - // format: on - ) - - def getName(): String = "prelude-replacements" - def transform(x: TransformContext): Model = { - val m = x.getModel() - val toAdd = - addIfUsed.flatMap { case (clazz, (shape, preludeShapeId)) => - if (m.toSet(clazz).size() > 0) { - List(shape, preludeModel.expectShape(preludeShapeId)) - } else List.empty - }.toList - - m.toBuilder() - .addShapes(toAdd.asJava) - .build() - } - } - - /** Conflicts for enum happens at the value level on the protobuf side. Two - * different enums in a same package can't have the same value. To catch - * this, we build a map of EnumDefinition -> Boolean where the right side - * is true if there is a conflict, false otherwise. - * - * To build the map, we select all EnumShapes. Then we build an - * intermediate map of all the resolved protobuf enum name (see - * #protoEnumName) by namespace. We use this to build the final lookup map - * where we run (eagerly) a conflict check for each Value of the Enum. - * found in the model. - */ - - val PreventEnumConflicts: ProjectionTransformer = - new ProjectionTransformer() { - - def getName(): String = "prevent-enum-conflicts" - def transform(x: TransformContext): Model = { - val currentModel = x.getModel - val enumsShapes: List[EnumShape] = currentModel - .getEnumShapes() - .asScala - .filterNot(Prelude.isPreludeShape) - .toList - - val intEnums: List[IntEnumShape] = currentModel - .getIntEnumShapes() - .asScala - .filterNot(Prelude.isPreludeShape) - .toList - - val allEnums: List[Shape] = enumsShapes ++ intEnums - - val allCombos = for { - e <- allEnums - memberName <- e.getMemberNames().asScala.toList - } yield (e.getId().getNamespace(), memberName) - - val allRepeatedCombos = - allCombos - .groupBy(identity) - .view - .mapValues(_.size) - .collect { - case (k, v) if v > 1 => k - } - .toSet - - def hasConflict(member: MemberShape): Boolean = allRepeatedCombos( - (member.getId().getNamespace(), member.getMemberName()) - ) - - val newEnumShapes: List[Shape] = enumsShapes.map { enumShape => - val b = enumShape.toBuilder - b.clearMembers() - enumShape.members.asScala.foreach { - case member if hasConflict(member) => - b.addMember(renameMember(member)) - case member => - b.addMember(member) - } - b.build() - } - - val newIntEnumShapes = intEnums.map { intEnumShape => - val b = intEnumShape.toBuilder - b.clearMembers() - intEnumShape.members.asScala.foreach { - case member if hasConflict(member) => - b.addMember(renameMember(member)) - case member => - b.addMember(member) - } - b.build() - } - - val allShapes = newEnumShapes ++ newIntEnumShapes - - x.getTransformer() - .replaceShapes( - currentModel, - allShapes.asJava - ) - } - - def renameMember(member: MemberShape): MemberShape = { - val name = - s"${member.getId.getName().toUpperCase()}_${member.getMemberName}" - member.toBuilder - .id(member.getId.withMember(name)) - .build() - } - } - - /** Transforms UUID into a structure that produces the following protobuf - * message: - * ```proto - * message UUID { - * int64 upper_bits = 1; - * int64 lower_bits = 2; - * } - * ``` - */ - val CompactUUID: ProjectionTransformer = - new ProjectionTransformer() { - - def getName(): String = "compact-alloy-uuid" - def transform(x: TransformContext): Model = { - val uuidShapeId = ShapeId.fromParts("alloy", "UUID") - val newUUIDShapeId = ShapeId.fromParts("smithytranslate", "UUID") - - /* Visitor to replace any reference to alloy#UUID in member shapes to - * a custom alloy#CompactUUID shape. - */ - val updateMemberShapes = new ShapeVisitor.Default[Shape]() { - override protected def getDefault(shape: Shape): Shape = - shape - - private def updateMember(shape: MemberShape): MemberShape = { - if (shape.getTarget() == uuidShapeId) { - shape.toBuilder().target(newUUIDShapeId).build() - } else { - shape - } - } - - override def structureShape(shape: StructureShape): Shape = { - shape - .toBuilder() - .members( - shape - .getAllMembers() - .values() - .stream() - .map[MemberShape](updateMember) - .collect(Collectors.toList()) - ) - .build() - - } - override def unionShape(shape: UnionShape): Shape = { - shape - .toBuilder() - .members( - shape - .getAllMembers() - .values() - .stream() - .map[MemberShape](updateMember) - .collect(Collectors.toList()) - ) - .build() - } - override def listShape(shape: ListShape): Shape = { - shape - .toBuilder() - .member(updateMember(shape.getMember())) - .build() - } - override def mapShape(shape: MapShape): Shape = { - shape - .toBuilder() - .key(updateMember(shape.getKey())) - .value(updateMember(shape.getValue())) - .build() - } - } - val uuidUsage = x - .getModel() - .getMemberShapes() - .stream() - .filter { _.getTarget() == uuidShapeId } - .count() - if (uuidUsage > 0) { - val updatedShapes: util.List[Shape] = x - .getModel() - .toSet() - .stream() - // remove reference to alloy#UUID - .filter(_.getId() != uuidShapeId) - .map[Shape] { _shape => - _shape.accept[Shape](updateMemberShapes) - } - .collect(Collectors.toList()) - Model - .builder() - .addShapes(updatedShapes) - .addShape(UUID.shape) - .build() - } else { - x.getModel() - } - } - } - - def all(allowedNamespace: Option[String]): List[ProjectionTransformer] = - Transitive(allowedNamespace) :: - PreludeReplacements :: - PreventEnumConflicts :: - CompactUUID :: - Nil - } - - def apply( - model: Model, - transformers: List[ProjectionTransformer] - ): Model = { - transformers.foldLeft(model) { (acc, transformer) => - transformer.transform(TransformContext.builder().model(acc).build()) - } - } -} diff --git a/modules/proto/src/smithytranslate/proto3/internals/Namespacing.scala b/modules/proto/src/smithytranslate/proto3/internals/Namespacing.scala index 00122f4..5126c04 100644 --- a/modules/proto/src/smithytranslate/proto3/internals/Namespacing.scala +++ b/modules/proto/src/smithytranslate/proto3/internals/Namespacing.scala @@ -13,12 +13,12 @@ * limitations under the License. */ -package smithyproto.proto3 +package smithytranslate.proto3.internals import software.amazon.smithy.model.shapes.ShapeId -import smithyproto.proto3.ProtoIR.Fqn +import ProtoIR.Fqn -object Namespacing { +private[internals] object Namespacing { def shapeIdToFqn(id: ShapeId): Fqn = Fqn(Some(namespaceToPackage(id.getNamespace)), id.getName) diff --git a/modules/proto/src/smithytranslate/proto3/internals/OutputFile.scala b/modules/proto/src/smithytranslate/proto3/internals/OutputFile.scala index 565302d..205a2fe 100644 --- a/modules/proto/src/smithytranslate/proto3/internals/OutputFile.scala +++ b/modules/proto/src/smithytranslate/proto3/internals/OutputFile.scala @@ -13,8 +13,11 @@ * limitations under the License. */ -package smithyproto.proto3 +package smithytranslate.proto3.internals -import smithyproto.proto3.ProtoIR.CompilationUnit +import ProtoIR.CompilationUnit -final case class OutputFile(path: List[String], unit: CompilationUnit) +private[internals] final case class OutputFile( + path: List[String], + unit: CompilationUnit +) diff --git a/modules/proto/src/smithytranslate/proto3/internals/ProtoIR.scala b/modules/proto/src/smithytranslate/proto3/internals/ProtoIR.scala index 2ff520c..271956e 100644 --- a/modules/proto/src/smithytranslate/proto3/internals/ProtoIR.scala +++ b/modules/proto/src/smithytranslate/proto3/internals/ProtoIR.scala @@ -13,9 +13,11 @@ * limitations under the License. */ -package smithyproto.proto3 +package smithytranslate.proto3.internals -object ProtoIR { +import software.amazon.smithy.model.shapes.ToShapeId + +private[internals] object ProtoIR { final case class CompilationUnit( packageName: Option[String], @@ -90,9 +92,13 @@ object ProtoIR { def importFqn: Set[Fqn] } object Type { + private def protobufFqn(last: String) = Fqn(Some(List("google", "protobuf")), last) + private def alloyFqn(last: String) = + Fqn(Some(List("alloy", "protobuf")), last) + sealed trait PrimitiveType extends Type { def importFqn: Set[Fqn] = Set.empty } @@ -113,21 +119,23 @@ object ProtoIR { case object String extends PrimitiveType case object Bytes extends PrimitiveType final case class MapType( - keyType: Either[Type.Int32.type, Type.String.type], + keyType: Type, valueType: Type ) extends Type { - val foldedKeyType: Type = keyType.fold(identity, identity) def importFqn: Set[Fqn] = - keyType.fold(_.importFqn, _.importFqn) ++ valueType.importFqn + keyType.importFqn ++ valueType.importFqn } final case class ListType(valueType: Type) extends Type { def importFqn: Set[Fqn] = valueType.importFqn } - final case class MessageType(fqn: Fqn, _importFqn: Fqn) extends Type { + final case class RefType(fqn: Fqn, _importFqn: Fqn) extends Type { def importFqn: Set[Fqn] = Set(_importFqn) } - final case class EnumType(fqn: Fqn, _importFqn: Fqn) extends Type { - def importFqn: Set[Fqn] = Set(_importFqn) + object RefType { + def apply(toShapeId: ToShapeId): RefType = RefType( + Namespacing.shapeIdToFqn(toShapeId.toShapeId()), + Namespacing.shapeIdToImportFqn(toShapeId.toShapeId()) + ) } case object Any extends Type { def importFqn = Set(protobufFqn("any")) @@ -138,54 +146,119 @@ object ProtoIR { val fqn: Fqn = protobufFqn("Empty") } - private val smithyTranslateImportFqn = - Namespacing.namespaceToFqn("smithytranslate") - - val BigInteger = MessageType( - Fqn(Some(List("smithytranslate")), "BigInteger"), - smithyTranslateImportFqn - ) - val BigDecimal = MessageType( - Fqn(Some(List("smithytranslate")), "BigDecimal"), - smithyTranslateImportFqn - ) - val Timestamp = MessageType( - Fqn(Some(List("smithytranslate")), "Timestamp"), - smithyTranslateImportFqn - ) + private val alloyTypesImport = + Fqn(Some(List("alloy", "protobuf")), "types") + + private val alloyWrappersImport = + Fqn(Some(List("alloy", "protobuf")), "wrappers") + + object AlloyTypes { + val CompactUUID = RefType( + alloyFqn("CompactUUID"), + alloyTypesImport + ) + } + + object AlloyWrappers { + val BigInteger = RefType( + alloyFqn("BigIntegerValue"), + alloyWrappersImport + ) + val BigDecimal = RefType( + alloyFqn("BigDecimalValue"), + alloyWrappersImport + ) + val ShortValue = RefType( + alloyFqn("ShortValue"), + alloyWrappersImport + ) + val Fixed32 = RefType( + alloyFqn("Fixed32Value"), + alloyWrappersImport + ) + val SFixed32 = RefType( + alloyFqn("SFixed32Value"), + alloyWrappersImport + ) + val Fixed64 = RefType( + alloyFqn("Fixed64Value"), + alloyWrappersImport + ) + val SFixed64 = RefType( + alloyFqn("SFixed64Value"), + alloyWrappersImport + ) + val SInt32 = RefType( + alloyFqn("SInt32Value"), + alloyWrappersImport + ) + val SInt64 = RefType( + alloyFqn("SInt64Value"), + alloyWrappersImport + ) + val ByteValue = RefType( + alloyFqn("ByteValue"), + alloyWrappersImport + ) + val Timestamp = RefType( + alloyFqn("TimestampValue"), + alloyWrappersImport + ) + val CompactUUID = RefType( + alloyFqn("CompactUUIDValue"), + alloyWrappersImport + ) + val Document = RefType( + alloyFqn("DocumentValue"), + alloyWrappersImport + ) + } // https://github.com/protocolbuffers/protobuf/blob/178ebc179ede26bcaa85b39db127ebf099be3ef8/src/google/protobuf/wrappers.proto - trait Wrappers extends Type { - def importFqn = Set(protobufFqn("wrappers")) + sealed trait PredefinedType extends Type { def fqn: Fqn } - object Wrappers { - case object Double extends Wrappers { + sealed trait GoogleWrappers extends PredefinedType { + def importFqn = Set(protobufFqn("wrappers")) + } + + case object GoogleValue extends PredefinedType { + def importFqn: Set[Fqn] = Set(protobufFqn("struct")) + def fqn = protobufFqn("Value") + } + + case object GoogleTimestamp extends PredefinedType { + def importFqn: Set[Fqn] = Set(protobufFqn("timestamp")) + def fqn = protobufFqn("Timestamp") + } + + object GoogleWrappers { + case object Double extends GoogleWrappers { def fqn: Fqn = protobufFqn("DoubleValue") } - case object Float extends Wrappers { + case object Float extends GoogleWrappers { def fqn: Fqn = protobufFqn("FloatValue") } - case object Int64 extends Wrappers { + case object Int64 extends GoogleWrappers { def fqn: Fqn = protobufFqn("Int64Value") } - case object Uint64 extends Wrappers { + case object Uint64 extends GoogleWrappers { def fqn: Fqn = protobufFqn("UInt64Value") } - case object Int32 extends Wrappers { + case object Int32 extends GoogleWrappers { def fqn: Fqn = protobufFqn("Int32Value") } - case object Uint32 extends Wrappers { + case object Uint32 extends GoogleWrappers { def fqn: Fqn = protobufFqn("UInt32Value") } - case object Bool extends Wrappers { + case object Bool extends GoogleWrappers { def fqn: Fqn = protobufFqn("BoolValue") } - case object String extends Wrappers { + case object String extends GoogleWrappers { def fqn: Fqn = protobufFqn("StringValue") } - case object Bytes extends Wrappers { + case object Bytes extends GoogleWrappers { def fqn: Fqn = protobufFqn("BytesValue") } } diff --git a/modules/proto/src/smithytranslate/proto3/internals/Renderer.scala b/modules/proto/src/smithytranslate/proto3/internals/Renderer.scala index 49461bf..fcf0982 100644 --- a/modules/proto/src/smithytranslate/proto3/internals/Renderer.scala +++ b/modules/proto/src/smithytranslate/proto3/internals/Renderer.scala @@ -13,10 +13,11 @@ * limitations under the License. */ -package smithyproto +package smithytranslate package proto3 +package internals -object Renderer { +private[proto3] object Renderer { import ProtoIR._ import Text._ @@ -159,14 +160,13 @@ object Renderer { case Bool => "bool" case String => "string" case Bytes => "bytes" - case ty @ MapType(_, valueType) => - s"map<${renderType(ty.foldedKeyType)}, ${renderType(valueType)}>" + case MapType(keyType, valueType) => + s"map<${renderType(keyType)}, ${renderType(valueType)}>" case ListType(valueType) => s"repeated ${renderType(valueType)}" - case MessageType(fqn, _) => fqn.render - case EnumType(fqn, _) => fqn.render + case RefType(fqn, _) => fqn.render case Any => Any.fqn.render case Empty => Empty.fqn.render - case w: Wrappers => w.fqn.render + case w: PredefinedType => w.fqn.render } } diff --git a/modules/proto/src/smithytranslate/proto3/internals/Text.scala b/modules/proto/src/smithytranslate/proto3/internals/Text.scala index 1c94fe6..1d766e8 100644 --- a/modules/proto/src/smithytranslate/proto3/internals/Text.scala +++ b/modules/proto/src/smithytranslate/proto3/internals/Text.scala @@ -13,15 +13,15 @@ * limitations under the License. */ -package smithyproto +package smithytranslate.proto3.internals import scala.annotation.tailrec /** A language for building and rendering structured text, including newlines * and indentation. */ -sealed trait Text -object Text { +private[internals] sealed trait Text +private[internals] object Text { case class Line(string: String) extends Text case class Many(texts: List[Text]) extends Text case class Indent(text: Text) extends Text diff --git a/modules/proto/tests/src/smithytranslate/proto3/internals/CompilerRendererSuite.scala b/modules/proto/tests/src/smithytranslate/proto3/internals/CompilerRendererSuite.scala index b24beea..0f054ac 100644 --- a/modules/proto/tests/src/smithytranslate/proto3/internals/CompilerRendererSuite.scala +++ b/modules/proto/tests/src/smithytranslate/proto3/internals/CompilerRendererSuite.scala @@ -13,21 +13,33 @@ * limitations under the License. */ -package smithyproto.proto3 +package smithytranslate.proto3.internals import munit._ import software.amazon.smithy.model.Model import software.amazon.smithy.model.validation.ValidatedResultException -import smithyproto.validation.ProtoValidator class CompilerRendererSuite extends FunSuite { test("top level - union") { val source = """|namespace com.example + | + |use alloy.proto#protoWrapped + | + |list StringList { + | member: String + |} + | + |map StringMap { + | key: String, + | value: String + |} | |union MyUnion { - | name: String, + | name: String | id: Integer + | stringList: StringList + | stringMap: StringMap |} |""".stripMargin @@ -35,10 +47,20 @@ class CompilerRendererSuite extends FunSuite { | |package com.example; | + |message StringList { + | repeated string value = 1; + |} + | + |message StringMap { + | map value = 1; + |} + | |message MyUnion { | oneof definition { | string name = 1; | int32 id = 2; + | com.example.StringList stringList = 3; + | com.example.StringMap stringMap = 4; | } |} |""".stripMargin @@ -131,268 +153,344 @@ class CompilerRendererSuite extends FunSuite { ) } - test("top level - document") { - val source = """|namespace com.example - | - |document SomeDoc - |""".stripMargin - val expected = """|syntax = "proto3"; - | - |package com.example; - | - |import "google/protobuf/any.proto"; - | - |message SomeDoc { - | google.protobuf.Any value = 1; - |} - |""".stripMargin - convertCheck(source, Map("com/example/example.proto" -> expected)) - } - - test("top level - string") { - val source = """|namespace com.example - | - |string SomeString - |""".stripMargin - val expected = """|syntax = "proto3"; - | - |package com.example; - | - |message SomeString { - | string value = 1; - |} - |""".stripMargin - convertCheck(source, Map("com/example/example.proto" -> expected)) - } - - test("top level - structure") { + test("document") { val source = """|namespace com.example - | - |structure MyStruct { - | @required - | value: String - |} - |""".stripMargin - val expected = """|syntax = "proto3"; - | - |package com.example; | - |message MyStruct { - | string value = 1; + |structure SomeDoc { + | value: Document |} |""".stripMargin - convertCheck(source, Map("com/example/example.proto" -> expected)) - } - test("top level - int") { - val source = """|namespace com.example - | - |integer SomeInt - |""".stripMargin val expected = """|syntax = "proto3"; - | - |package com.example; - | - |message SomeInt { - | int32 value = 1; - |} - |""".stripMargin - convertCheck(source, Map("com/example/example.proto" -> expected)) - } + | + |package com.example; + | + |import "google/protobuf/struct.proto"; + | + |message SomeDoc { + | google.protobuf.Value value = 1; + |} + |""".stripMargin - test("top level - long") { - val source = """|namespace com.example - | - |long SomeLong - |""".stripMargin - val expected = """|syntax = "proto3"; - | - |package com.example; - | - |message SomeLong { - | int64 value = 1; - |} - |""".stripMargin convertCheck(source, Map("com/example/example.proto" -> expected)) } - test("top level - double") { + test("Primitive fields") { val source = """|namespace com.example - | - |double SomeDouble - |""".stripMargin - val expected = """|syntax = "proto3"; - | - |package com.example; | - |message SomeDouble { - | double value = 1; + |structure Struct { + | boolean: Boolean + | int: Integer + | long: Long + | byte: Byte + | short: Short + | float: Float + | double: Double + | bigInteger: BigInteger + | bigDecimal: BigDecimal + | blob: Blob + | document: Document + | string: String + | timestamp: Timestamp |} |""".stripMargin - convertCheck(source, Map("com/example/example.proto" -> expected)) - } - - test("top level - float") { - val source = """|namespace com.example - | - |float SomeFloat - |""".stripMargin val expected = """|syntax = "proto3"; - | - |package com.example; - | - |message SomeFloat { - | float value = 1; - |} - |""".stripMargin + | + |package com.example; + | + |import "google/protobuf/struct.proto"; + | + |import "google/protobuf/timestamp.proto"; + | + |message Struct { + | bool boolean = 1; + | int32 int = 2; + | int64 long = 3; + | int32 byte = 4; + | int32 short = 5; + | float float = 6; + | double double = 7; + | string bigInteger = 8; + | string bigDecimal = 9; + | bytes blob = 10; + | google.protobuf.Value document = 11; + | string string = 12; + | google.protobuf.Timestamp timestamp = 13; + |} + |""".stripMargin convertCheck(source, Map("com/example/example.proto" -> expected)) - } - test("top level - short") { + test("Primitive fields (wrapped)") { val source = """|namespace com.example - | - |short SomeShort - |""".stripMargin - val expected = """|syntax = "proto3"; | - |package com.example; + |use alloy.proto#protoWrapped | - |message SomeShort { - | int32 value = 1; + |structure Struct { + | @protoWrapped + | boolean: Boolean + | @protoWrapped + | int: Integer + | @protoWrapped + | long: Long + | @protoWrapped + | byte: Byte + | @protoWrapped + | short: Short + | @protoWrapped + | float: Float + | @protoWrapped + | double: Double + | @protoWrapped + | bigInteger: BigInteger + | @protoWrapped + | bigDecimal: BigDecimal + | @protoWrapped + | blob: Blob + | @protoWrapped + | document: Document + | @protoWrapped + | string: String + | @protoWrapped + | timestamp: Timestamp |} |""".stripMargin - convertCheck(source, Map("com/example/example.proto" -> expected)) - } - test("top level - bool") { - val source = """|namespace com.example - | - |boolean SomeBool - |""".stripMargin val expected = """|syntax = "proto3"; - | - |package com.example; - | - |message SomeBool { - | bool value = 1; - |} - |""".stripMargin + | + |package com.example; + | + |import "google/protobuf/wrappers.proto"; + | + |import "alloy/protobuf/wrappers.proto"; + | + |message Struct { + | google.protobuf.BoolValue boolean = 1; + | google.protobuf.Int32Value int = 2; + | google.protobuf.Int64Value long = 3; + | alloy.protobuf.ByteValue byte = 4; + | alloy.protobuf.ShortValue short = 5; + | google.protobuf.FloatValue float = 6; + | google.protobuf.DoubleValue double = 7; + | alloy.protobuf.BigIntegerValue bigInteger = 8; + | alloy.protobuf.BigDecimalValue bigDecimal = 9; + | google.protobuf.BytesValue blob = 10; + | alloy.protobuf.DocumentValue document = 11; + | google.protobuf.StringValue string = 12; + | alloy.protobuf.TimestampValue timestamp = 13; + |} + |""".stripMargin convertCheck(source, Map("com/example/example.proto" -> expected)) } - test("top level - bytes") { + test("Primitive references") { val source = """|namespace com.example - | - |blob SomeBlob - |""".stripMargin - val someBlob = """|syntax = "proto3"; | - |package com.example; + |use alloy.proto#protoWrapped + | + |boolean MyBoolean + |integer MyInt + |long MyLong + |byte MyByte + |short MyShort + |float MyFloat + |double MyDouble + |bigInteger MyBigInt + |bigDecimal MyBigDecimal + |blob MyBlob + |document MyDocument + |string MyString + |timestamp MyTimestamp | - |message SomeBlob { - | bytes value = 1; + |structure Struct { + | boolean: Boolean + | int: Integer + | long: Long + | byte: Byte + | short: Short + | float: Float + | double: Double + | bigInteger: BigInteger + | bigDecimal: BigDecimal + | blob: Blob + | document: Document + | string: String + | timestamp: Timestamp |} |""".stripMargin - convertCheck(source, Map("com/example/example.proto" -> someBlob)) - } - test("top level - big integer") { - val source = """|namespace com.example - | - |bigInteger SomeBigInt - |""".stripMargin val expected = """|syntax = "proto3"; - | - |package com.example; - | - |import "smithytranslate/definitions.proto"; - | - |message SomeBigInt { - | smithytranslate.BigInteger value = 1; - |} - |""".stripMargin + | + |package com.example; + | + |import "google/protobuf/struct.proto"; + | + |import "google/protobuf/timestamp.proto"; + | + |message Struct { + | bool boolean = 1; + | int32 int = 2; + | int64 long = 3; + | int32 byte = 4; + | int32 short = 5; + | float float = 6; + | double double = 7; + | string bigInteger = 8; + | string bigDecimal = 9; + | bytes blob = 10; + | google.protobuf.Value document = 11; + | string string = 12; + | google.protobuf.Timestamp timestamp = 13; + |} + |""".stripMargin convertCheck(source, Map("com/example/example.proto" -> expected)) } - test("top level - big decimal") { - val source = """|namespace com.example - | - |bigDecimal SomeBigDec - |""".stripMargin - val expected = """|syntax = "proto3"; - | - |package com.example; - | - |import "smithytranslate/definitions.proto"; - | - |message SomeBigDec { - | smithytranslate.BigDecimal value = 1; - |} - |""".stripMargin - convertWithApiCheck( - source, - Map("com/example/example.proto" -> expected) - ) - } - - test("top level - timestamp") { + test("Primitives reference (wrapped)") { val source = """|namespace com.example - | - |timestamp SomeTs - |""".stripMargin - val expected = """|syntax = "proto3"; | - |package com.example; + |use alloy.proto#protoWrapped + | + |@protoWrapped + |boolean MyBoolean + |@protoWrapped + |integer MyInt + |@protoWrapped + |long MyLong + |@protoWrapped + |byte MyByte + |@protoWrapped + |short MyShort + |@protoWrapped + |float MyFloat + |@protoWrapped + |double MyDouble + |@protoWrapped + |bigInteger MyBigInt + |@protoWrapped + |bigDecimal MyBigDecimal + |@protoWrapped + |blob MyBlob + |@protoWrapped + |document MyDocument + |@protoWrapped + |string MyString + |@protoWrapped + |timestamp MyTimestamp | - |import "smithytranslate/definitions.proto"; - | - |message SomeTs { - | smithytranslate.Timestamp value = 1; + |structure Struct { + | boolean: MyBoolean + | int: MyInt + | long: MyLong + | byte: MyByte + | short: MyShort + | float: MyFloat + | double: MyDouble + | bigInteger: MyBigInt + | bigDecimal: MyBigDecimal + | blob: MyBlob + | document: MyDocument + | string: MyString + | timestamp: MyTimestamp |} |""".stripMargin - convertWithApiCheck( - source, - Map("com/example/example.proto" -> expected) - ) - } - test("proto top-level deprecated") { - val source = """|$version: "2" - | - |namespace another.namespace - | - |@deprecated - |string SomeString - |""".stripMargin val expected = """|syntax = "proto3"; | - |package another.namespace; + |package com.example; | - |message SomeString { - | string value = 1 [deprecated = true]; + |import "google/protobuf/struct.proto"; + | + |import "google/protobuf/timestamp.proto"; + | + |message MyBoolean { + | bool value = 1; + |} + | + |message MyInt { + | int32 value = 1; + |} + | + |message MyLong { + | int64 value = 1; + |} + | + |message MyByte { + | int32 value = 1; + |} + | + |message MyShort { + | int32 value = 1; + |} + | + |message MyFloat { + | float value = 1; + |} + | + |message MyDouble { + | double value = 1; + |} + | + |message MyBigInt { + | string value = 1; + |} + | + |message MyBigDecimal { + | string value = 1; + |} + | + |message MyBlob { + | bytes value = 1; + |} + | + |message MyDocument { + | google.protobuf.Value value = 1; + |} + | + |message MyString { + | string value = 1; + |} + | + |message MyTimestamp { + | google.protobuf.Timestamp value = 1; + |} + | + |message Struct { + | com.example.MyBoolean boolean = 1; + | com.example.MyInt int = 2; + | com.example.MyLong long = 3; + | com.example.MyByte byte = 4; + | com.example.MyShort short = 5; + | com.example.MyFloat float = 6; + | com.example.MyDouble double = 7; + | com.example.MyBigInt bigInteger = 8; + | com.example.MyBigDecimal bigDecimal = 9; + | com.example.MyBlob blob = 10; + | com.example.MyDocument document = 11; + | com.example.MyString string = 12; + | com.example.MyTimestamp timestamp = 13; |} |""".stripMargin - convertCheck(source, Map("another/namespace/namespace.proto" -> expected)) + convertCheck(source, Map("com/example/example.proto" -> expected)) } - test("proto structure deprecated") { + test("deprecated field") { val source = """|$version: "2" | |namespace another.namespace | - |structure MyStruct { - | @required + |structure Struct { | @deprecated | value: String |} |""".stripMargin + val expected = """|syntax = "proto3"; | |package another.namespace; | - |message MyStruct { + |message Struct { | string value = 1 [deprecated = true]; |} |""".stripMargin @@ -401,156 +499,134 @@ class CompilerRendererSuite extends FunSuite { test("protoNumType") { val source = """|namespace com.example - | - |use alloy.proto#protoNumType - | - |structure LongNumbers { - | @protoNumType("SIGNED") - | signed: Long, - | - | @protoNumType("UNSIGNED") - | unsigned: Long, - | - | @protoNumType("FIXED") - | fixed: Long, - | - | @protoNumType("FIXED_SIGNED") - | FIXED_SIGNED: Long - |} - | - |structure IntNumbers { - | @protoNumType("SIGNED") - | signed: Integer, - | - | @protoNumType("UNSIGNED") - | unsigned: Integer, - | - | @protoNumType("FIXED") - | fixed: Integer, - | - | @protoNumType("FIXED_SIGNED") - | FIXED_SIGNED: Integer - |} - | - |structure RequiredLongNumbers { - | @protoNumType("SIGNED") - | @required - | signed: Long, - | - | @protoNumType("UNSIGNED") - | @required - | unsigned: Long, - | - | @protoNumType("FIXED") - | @required - | fixed: Long, - | - | @protoNumType("FIXED_SIGNED") - | @required - | FIXED_SIGNED: Long - |} - | - |structure RequiredIntNumbers { - | @protoNumType("SIGNED") - | @required - | signed: Integer, - | - | @protoNumType("UNSIGNED") - | @required - | unsigned: Integer, - | - | @protoNumType("FIXED") - | @required - | fixed: Integer, - | - | @protoNumType("FIXED_SIGNED") - | @required - | FIXED_SIGNED: Integer - |} - |""".stripMargin + | + |use alloy.proto#protoNumType + |use alloy.proto#protoWrapped + | + |structure LongNumbers { + | @protoNumType("SIGNED") + | signed: Long, + | + | @protoNumType("UNSIGNED") + | unsigned: Long, + | + | @protoNumType("FIXED") + | fixed: Long, + | + | @protoNumType("FIXED_SIGNED") + | fixed_signed: Long + |} + | + |structure IntNumbers { + | @protoNumType("SIGNED") + | signed: Integer, + | + | @protoNumType("UNSIGNED") + | unsigned: Integer, + | + | @protoNumType("FIXED") + | fixed: Integer, + | + | @protoNumType("FIXED_SIGNED") + | fixed_signed: Integer + |} + | + |structure WrappedLongNumbers { + | @protoNumType("SIGNED") + | @protoWrapped + | signed: Long, + | + | @protoNumType("UNSIGNED") + | @protoWrapped + | unsigned: Long, + | + | @protoNumType("FIXED") + | @protoWrapped + | fixed: Long, + | + | @protoNumType("FIXED_SIGNED") + | @protoWrapped + | fixed_signed: Long + |} + | + |structure WrappedIntNumbers { + | @protoNumType("SIGNED") + | @protoWrapped + | signed: Integer, + | + | @protoNumType("UNSIGNED") + | @protoWrapped + | unsigned: Integer, + | + | @protoNumType("FIXED") + | @protoWrapped + | fixed: Integer, + | + | @protoNumType("FIXED_SIGNED") + | @protoWrapped + | fixed_signed: Integer + |} + |""".stripMargin val expected = """|syntax = "proto3"; | |package com.example; | + |import "alloy/protobuf/wrappers.proto"; + | |import "google/protobuf/wrappers.proto"; | |message LongNumbers { - | google.protobuf.Int64Value signed = 1; - | google.protobuf.UInt64Value unsigned = 2; - | google.protobuf.Int64Value fixed = 3; - | google.protobuf.Int64Value FIXED_SIGNED = 4; - |} - | - |message IntNumbers { - | google.protobuf.Int32Value signed = 1; - | google.protobuf.UInt32Value unsigned = 2; - | google.protobuf.Int32Value fixed = 3; - | google.protobuf.Int32Value FIXED_SIGNED = 4; - |} - | - |message RequiredLongNumbers { | sint64 signed = 1; | uint64 unsigned = 2; | fixed64 fixed = 3; - | sfixed64 FIXED_SIGNED = 4; + | sfixed64 fixed_signed = 4; |} | - |message RequiredIntNumbers { + |message IntNumbers { | sint32 signed = 1; | uint32 unsigned = 2; | fixed32 fixed = 3; - | sfixed32 FIXED_SIGNED = 4; - |}""".stripMargin + | sfixed32 fixed_signed = 4; + |} + | + |message WrappedLongNumbers { + | alloy.protobuf.SInt64Value signed = 1; + | google.protobuf.UInt64Value unsigned = 2; + | alloy.protobuf.Fixed64Value fixed = 3; + | alloy.protobuf.SFixed64Value fixed_signed = 4; + |} + | + |message WrappedIntNumbers { + | alloy.protobuf.SInt32Value signed = 1; + | google.protobuf.UInt32Value unsigned = 2; + | alloy.protobuf.Fixed32Value fixed = 3; + | alloy.protobuf.SFixed32Value fixed_signed = 4; + |} + |""".stripMargin convertCheck( source, Map("com/example/example.proto" -> expected) ) } - test("inlined sparse maps") { + test("maps") { val source = """|namespace com.example - | - |@sparse - |map StringMap { - | key: String, - | value: String - |} - | - |structure Foo { - | object: StringMap - |} - |""".stripMargin - val expected = """|syntax = "proto3"; | - |package com.example; + |structure MapItem { + | @required + | name: String + |} | - |import "google/protobuf/wrappers.proto"; + |map Map { + | key: String, + | value: MapItem + |} | - |message Foo { - | map object = 1; + |structure Foo { + | values: Map |} |""".stripMargin - convertCheck(source, Map("com/example/example.proto" -> expected)) - } - - test("inlined maps message") { - val source = """|namespace com.example - | - |structure MapItem { - | @required - | name: String - |} - | - |map Map { - | key: String, - | value: MapItem - |} - | - |structure Foo { - | values: Map - |} - |""".stripMargin val expected = """|syntax = "proto3"; | |package com.example; @@ -566,202 +642,249 @@ class CompilerRendererSuite extends FunSuite { convertCheck(source, Map("com/example/example.proto" -> expected)) } - test("inlined maps") { + test("maps (bis)") { val source = """|namespace com.example - | - |map StringMap { - | key: String, - | value: Integer - |} - | - |structure Foo { - | strings: StringMap - |} - |""".stripMargin - val expected = """|syntax = "proto3"; | - |package com.example; + |map StringMap { + | key: String, + | value: Integer + |} | - |message Foo { - | map strings = 1; + |structure Foo { + | strings: StringMap + | @alloy.proto#protoWrapped + | wrappedStrings: StringMap |} |""".stripMargin + val expected = """|syntax = "proto3"; + | + |package com.example; + | + |message StringMap { + | map value = 1; + |} + | + |message Foo { + | map strings = 1; + | com.example.StringMap wrappedStrings = 2; + |} + |""".stripMargin convertCheck(source, Map("com/example/example.proto" -> expected)) } - test("inlined lists") { + test("maps (wrapped)") { val source = """|namespace com.example - | - |list StringList { - | member: String - |} - | - |structure Foo { - | strings: StringList - |} - |""".stripMargin - val expected = """|syntax = "proto3"; | - |package com.example; + |use alloy.proto#protoWrapped + | + |@protoWrapped + |map StringMap { + | key: String, + | @protoWrapped + | value: Integer + |} | - |message Foo { - | repeated string strings = 1; + |structure Foo { + | strings: StringMap |} |""".stripMargin + val expected = + """|syntax = "proto3"; + | + |package com.example; + | + |import "google/protobuf/wrappers.proto"; + | + |message StringMap { + | map value = 1; + |} + | + |message Foo { + | com.example.StringMap strings = 1; + |} + |""".stripMargin convertCheck(source, Map("com/example/example.proto" -> expected)) } - test("inlined sparse lists") { + test("lists") { val source = """|namespace com.example - | - |@sparse - |list StringList { - | member: String - |} - | - |structure Foo { - | strings: StringList - |} - |""".stripMargin - val expected = """|syntax = "proto3"; - | - |package com.example; | - |import "google/protobuf/wrappers.proto"; + |list StringList { + | member: String + |} | - |message Foo { - | repeated google.protobuf.StringValue strings = 1; + |structure Foo { + | strings: StringList |} |""".stripMargin + val expected = """|syntax = "proto3"; + | + |package com.example; + | + |message Foo { + | repeated string strings = 1; + |} + |""".stripMargin convertCheck(source, Map("com/example/example.proto" -> expected)) } - test("inlined lists message") { + test("lists (bis)") { val source = """|namespace com.example - | - |structure ListItem { - | @required - | name: String - |} - | - |list List { - | member: ListItem - |} - | - |structure Foo { - | strings: List - |} - |""".stripMargin + | + |structure ListItem { + | @required + | name: String + |} + | + |list List { + | member: ListItem + |} + | + |structure Foo { + | strings: List + | @alloy.proto#protoWrapped + | wrappedStrings: List + |} + |""".stripMargin val expected = """|syntax = "proto3"; - | - |package com.example; - | - |message ListItem { - | string name = 1; - |} - | - |message Foo { - | repeated com.example.ListItem strings = 1; - |} - |""".stripMargin + | + |package com.example; + | + |message ListItem { + | string name = 1; + |} + | + |message List { + | repeated com.example.ListItem value = 1; + |} + | + |message Foo { + | repeated com.example.ListItem strings = 1; + | com.example.List wrappedStrings = 2; + |} + |""".stripMargin convertCheck(source, Map("com/example/example.proto" -> expected)) } - test("inline list with a union") { + test("lists (wrapped)") { val source = """|namespace com.example | - |union MyUnion { - | name: String, - | id: Integer - |} - | - |structure UnionStruct { - | @required - | value: MyUnion - |} + |use alloy.proto#protoWrapped | - |list ListOfUnion { - | member: UnionStruct + |@protoWrapped + |list StringList { + | @protoWrapped + | member: String |} | - |structure Unions { - | @required - | values: ListOfUnion + |structure Foo { + | strings: StringList |} |""".stripMargin val expected = """|syntax = "proto3"; | |package com.example; | - |message MyUnion { - | oneof definition { - | string name = 1; - | int32 id = 2; - | } - |} + |import "google/protobuf/wrappers.proto"; | - |message UnionStruct { - | com.example.MyUnion value = 1; + |message StringList { + | repeated google.protobuf.StringValue value = 1; |} | - |message Unions { - | repeated com.example.UnionStruct values = 1; - |}""".stripMargin + |message Foo { + | com.example.StringList strings = 1; + |} + |""".stripMargin convertCheck(source, Map("com/example/example.proto" -> expected)) } test("transitive structure with protoEnabled") { val source = """|namespace test - | - |use alloy.proto#protoEnabled - | - |@protoEnabled - |structure Test { - | o: Other - |} - | - |structure Other { - | s: String - |} - |""".stripMargin - val expected = """|syntax = "proto3"; | - |package test; + |use alloy.proto#protoEnabled | - |import "google/protobuf/wrappers.proto"; + |@protoEnabled + |structure Test { + | o: Other + |} | - |message Test { - | test.Other o = 1; + |structure Other { + | s: String |} + |""".stripMargin + val expected = """|syntax = "proto3"; + | + |package test; + | + |message Test { + | test.Other o = 1; + |} + | + |message Other { + | string s = 1; + |} + |""".stripMargin + convertCheck( + source, + Map("test/definitions.proto" -> expected), + allShapes = false + ) + + } + + test("uuid translates to string by default") { + val source = """|namespace test + | + |use alloy#uuidFormat | - |message Other { - | google.protobuf.StringValue s = 1; + |@uuidFormat + |string MyUUID + | + |structure Test { + | id: MyUUID |} |""".stripMargin - convertCheck(source, Map("test/definitions.proto" -> expected)) + val expected = """|syntax = "proto3"; + | + |package test; + | + |message Test { + | string id = 1; + |}""".stripMargin + + convertCheck( + source, + Map("test/definitions.proto" -> expected) + ) } - test("uuid can be used") { + test( + "uuid translates to message by when annotated with @protoCompactUUID" + ) { val source = """|namespace test | - |use smithytranslate#UUID + |use alloy#uuidFormat + |use alloy.proto#protoCompactUUID + | + |@uuidFormat + |@protoCompactUUID + |string MyUUID | |structure Test { - | @required - | id: UUID + | id: MyUUID |} |""".stripMargin val expected = """|syntax = "proto3"; | |package test; | - |import "smithytranslate/definitions.proto"; + |import "alloy/protobuf/types.proto"; | |message Test { - | smithytranslate.UUID id = 1; + | alloy.protobuf.CompactUUID id = 1; |}""".stripMargin convertCheck( @@ -799,68 +922,80 @@ class CompilerRendererSuite extends FunSuite { test("service with protoEnabled") { val source = """|namespace test - | - |use alloy.proto#protoEnabled - | - |@protoEnabled - |service Test { - | operations: [Op] - |} - | - |@http(method: "POST", uri: "/test", code: 200) - |operation Op { - | input: Struct, - | output: Struct - |} - | - |structure Struct { - | s: String - |} - | - |/// This one should not be converted - |service Other { - | operations: [Op] - |} - |""".stripMargin + | + |use alloy.proto#protoEnabled + | + |@protoEnabled + |service Test { + | operations: [Op] + |} + | + |@http(method: "POST", uri: "/test", code: 200) + |operation Op { + | input: Struct, + | output: Struct + |} + | + |structure Struct { + | s: String + |} + | + |/// This one should not be converted + |service Other { + | operations: [Op] + |} + |""".stripMargin val expected = """|syntax = "proto3"; | |package test; | - |import "google/protobuf/wrappers.proto"; - | |service Test { | rpc Op(test.Struct) returns (test.Struct); |} | |message Struct { - | google.protobuf.StringValue s = 1; + | string s = 1; |} |""".stripMargin - convertCheck(source, Map("test/definitions.proto" -> expected)) + convertCheck( + source, + Map("test/definitions.proto" -> expected), + allShapes = false + ) } - test("enum with protoIndex") { + test("enum without protoIndex") { val source = """|$version: "2" |namespace test | |use alloy.proto#protoIndex - |use alloy.proto#protoEnabled - | - |@protoEnabled - |service Test { - | operations: [Op] - |} | - |@http(method: "POST", uri: "/test", code: 200) - |operation Op { - | input: Struct, - | output: Struct + |enum LoveProto { + | YES + | NO |} + |""".stripMargin + val expected = """|syntax = "proto3"; + | + |package test; + | + |enum LoveProto { + | YES = 0; + | NO = 1; + |}""".stripMargin + + convertCheck( + source, + Map("test/definitions.proto" -> expected) + ) + } + + test("enum with protoIndex") { + val source = """|$version: "2" + |namespace test | - |structure Struct { - | s: LoveProto - | } + |use alloy.proto#protoIndex | |enum LoveProto { | @protoIndex(0) @@ -873,13 +1008,59 @@ class CompilerRendererSuite extends FunSuite { | |package test; | - |service Test { - | rpc Op(test.Struct) returns (test.Struct); - |} + |enum LoveProto { + | YES = 0; + | NO = 2; + |}""".stripMargin + + convertCheck( + source, + Map("test/definitions.proto" -> expected) + ) + } + + test("intEnum without protoIndex") { + val source = """|$version: "2" + |namespace test + | + |use alloy.proto#protoIndex + | + |intEnum LoveProto { + | YES = 0 + | NO = 2 + |} + |""".stripMargin + val expected = """|syntax = "proto3"; | - |message Struct { - | test.LoveProto s = 1; - |} + |package test; + | + |enum LoveProto { + | YES = 0; + | NO = 2; + |}""".stripMargin + + convertCheck( + source, + Map("test/definitions.proto" -> expected) + ) + } + + test("intEnum with protoIndex") { + val source = """|$version: "2" + |namespace test + | + |use alloy.proto#protoIndex + | + |intEnum LoveProto { + | @protoIndex(0) + | YES = 1 + | @protoIndex(2) + | NO = 3 + |} + |""".stripMargin + val expected = """|syntax = "proto3"; + | + |package test; | |enum LoveProto { | YES = 0; @@ -892,6 +1073,96 @@ class CompilerRendererSuite extends FunSuite { ) } + test("open string enum") { + val source = """|$version: "2" + |namespace test + | + |use alloy#openEnum + | + |structure EnumWrapper { + | value: LoveProto + |} + | + |@openEnum + |enum LoveProto { + | YES + | NO + |} + |""".stripMargin + val expected = """|syntax = "proto3"; + | + |package test; + | + |message EnumWrapper { + | string value = 1; + |}""".stripMargin + + convertCheck( + source, + Map("test/definitions.proto" -> expected) + ) + } + + test("open intEnum") { + val source = """|$version: "2" + |namespace test + | + |use alloy#openEnum + | + |structure EnumWrapper { + | value: LoveProto + |} + | + |@openEnum + |intEnum LoveProto { + | YES = 1 + | NO = 3 + |} + |""".stripMargin + val expected = """|syntax = "proto3"; + | + |package test; + | + |message EnumWrapper { + | int32 value = 1; + |}""".stripMargin + + convertCheck( + source, + Map("test/definitions.proto" -> expected) + ) + } + + test("conflicting enum values") { + val source = """|$version: "2" + |namespace test + | + |enum MyStringEnum { + | FOO + |} + | + |intEnum MyIntEnum { + | FOO = 0 + |} + |""".stripMargin + val expected = """|syntax = "proto3"; + | + |package test; + | + |enum MyStringEnum { + | MYSTRINGENUM_FOO = 0; + |} + | + |enum MyIntEnum { + | MYINTENUM_FOO = 0; + |}""".stripMargin + + convertCheck( + source, + Map("test/definitions.proto" -> expected) + ) + } + test("union with protoIndex") { val source = """|$version: "2" |namespace test @@ -923,7 +1194,7 @@ class CompilerRendererSuite extends FunSuite { ) } - test("union with @protoInlinedOneOf and @protoIndex") { + test("union with @protoInlinedOneOf and @protoIndex (invalid)") { val source = """|$version: "2" |namespace test | @@ -1081,13 +1352,15 @@ class CompilerRendererSuite extends FunSuite { test("multiple namespaces") { def src(ns: String) = s"""|namespace com.$ns | - |string SomeString + |structure Struct { + | value: String + |} |""".stripMargin def expected(ns: String) = s"""|syntax = "proto3"; | |package com.$ns; | - |message SomeString { + |message Struct { | string value = 1; |} |""".stripMargin @@ -1112,7 +1385,9 @@ class CompilerRendererSuite extends FunSuite { | |namespace another.namespace | - |string SomeString + |structure Struct { + | value: String + |} |""".stripMargin val expected = """|syntax = "proto3"; | @@ -1121,75 +1396,36 @@ class CompilerRendererSuite extends FunSuite { | |package another.namespace; | - |message SomeString { + |message Struct { | string value = 1; |} |""".stripMargin convertCheck(source, Map("another/namespace/namespace.proto" -> expected)) } - /** Perform the same check as convertCheck but include the smithytranslate - * namespace. To do so it prepends the proto api to your `expected` value. - */ - private def convertWithApiCheck( - source: String, - expected: Map[String, String] - )(implicit loc: Location): Unit = { - val expectedApi = s"""|syntax = "proto3"; - | - |package smithytranslate; - | - |message BigDecimal { - | string value = 1; - |} - | - |message BigInteger { - | string value = 1; - |} - | - |message Timestamp { - | int64 value = 1; - |} - | - |message UUID { - | int64 upper_bits = 1; - | int64 lower_bits = 2; - |} - |""".stripMargin - val newExpected = Map( - "smithytranslate/definitions.proto" -> expectedApi - ) ++ expected - convertCheck(source, newExpected, excludeProtoApi = false) - } - private def convertCheck( source: String, expected: Map[String, String], - excludeProtoApi: Boolean = true + allShapes: Boolean = true )(implicit loc: Location): Unit = { convertChecks( Map("inlined-in-test.smithy" -> source), expected, - excludeProtoApi + allShapes ) } private def convertChecks( sources: Map[String, String], expected: Map[String, String], - excludeProtoApi: Boolean = true + allShapes: Boolean = true )(implicit loc: Location): Unit = { def render(srcs: Map[String, String]): List[(String, String)] = { val m = { val assembler = Model .assembler() .discoverModels() - .addShapes( - smithytranslate.BigInteger.shape, - smithytranslate.BigDecimal.shape, - smithytranslate.Timestamp.shape, - smithytranslate.UUID.shape - ) + srcs.foreach { case (name, src) => assembler.addUnparsedModel(name, src) } @@ -1198,36 +1434,26 @@ class CompilerRendererSuite extends FunSuite { .assemble() .unwrap() } - val c = new Compiler() - val res = c.compile(m) - if (res.isEmpty) { fail("Expected compiler output") } + val c = new Compiler(m, allShapes = allShapes) + val res = c.compile() + if (res.isEmpty) { fail("Compiler didn't produce any output") } res.map { of => val fileName = of.path.mkString("/") fileName -> Renderer.render(of.unit) } } - val actual = render(sources).sortWith { case ((name1, _), (_, _)) => - name1.startsWith("smithytranslate") - } - ProtoValidator.run(actual: _*) - val exclude = - if (excludeProtoApi) - Set("smithytranslate/definitions.proto") - else Set.empty[String] - - val finalFiles = actual.collect { - case (name, contents) if !exclude(name) => (name, contents) - }.toMap + val renderedFiles = render(sources).toMap // Checking that we get the same keyset as expected - assertEquals(finalFiles.keySet, expected.keySet) + assertEquals(renderedFiles.keySet, expected.keySet) // Checking that all contents match for { (file, content) <- expected } { - assertEquals(finalFiles(file).trim(), content.trim()) + assertEquals(renderedFiles(file).trim(), content.trim()) } + ProtoValidator.run(renderedFiles.toSeq: _*) } } diff --git a/modules/proto/tests/src/smithytranslate/proto3/internals/CompilerSuite.scala b/modules/proto/tests/src/smithytranslate/proto3/internals/CompilerSuite.scala index c7623a1..18b77d8 100644 --- a/modules/proto/tests/src/smithytranslate/proto3/internals/CompilerSuite.scala +++ b/modules/proto/tests/src/smithytranslate/proto3/internals/CompilerSuite.scala @@ -13,19 +13,22 @@ * limitations under the License. */ -package smithyproto.proto3 +package smithytranslate.proto3.internals import munit._ -import smithyproto.proto3.ProtoIR._ -import smithyproto.proto3.ProtoIR.Statement._ +import ProtoIR._ +import Statement._ import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import alloy.proto.ProtoEnabledTrait +import software.amazon.smithy.model.shapes.ShapeId class CompilerSuite extends FunSuite { - private val someString = TopLevelDef.MessageDef( + private val someStruct = TopLevelDef.MessageDef( Message( - "SomeString", + "Struct", List( MessageElement.FieldElement( Field( @@ -42,15 +45,26 @@ class CompilerSuite extends FunSuite { test("compile a simple smithy model") { val namespace = "com.example" - val sut = new Compiler() val model = { val mb = Model.builder() mb.addShape( - StringShape.builder().id(s"$namespace#SomeString").build() + StringShape + .builder() + .id(ShapeId.fromParts("com.example", "MyString")) + .build() + ) + mb.addShape( + StructureShape + .builder() + .addTrait(new ProtoEnabledTrait()) + .id(s"$namespace#Struct") + .addMember("value", ShapeId.fromParts("com.example", "MyString")) + .build() ) mb.build() } - val actual = sut.compile(model) + val sut = new Compiler(model, allShapes = true) + val actual = sut.compile() val expected = List( OutputFile( List( @@ -64,7 +78,7 @@ class CompilerSuite extends FunSuite { ), List( TopLevelStatement( - someString + someStruct ) ), List.empty @@ -114,15 +128,26 @@ class CompilerSuite extends FunSuite { private def namespaceTest(namespace: String, expectedFilePath: List[String])( implicit loc: Location ): Unit = { - val sut = new Compiler() val model = { val mb = Model.builder() mb.addShape( - StringShape.builder().id(s"$namespace#SomeString").build() + StringShape + .builder() + .id(ShapeId.fromParts("com.example", "MyString")) + .build() + ) + mb.addShape( + StructureShape + .builder() + .addTrait(new ProtoEnabledTrait()) + .id(s"$namespace#Struct") + .addMember("value", ShapeId.fromParts("com.example", "MyString")) + .build() ) mb.build() } - val actual = sut.compile(model) + val sut = new Compiler(model, allShapes = true) + val actual = sut.compile() val expected = List( OutputFile( expectedFilePath, @@ -132,7 +157,7 @@ class CompilerSuite extends FunSuite { ), List( TopLevelStatement( - someString + someStruct ) ), List.empty diff --git a/modules/proto/tests/src/smithytranslate/proto3/internals/ModelPrePocessorSpec.scala b/modules/proto/tests/src/smithytranslate/proto3/internals/ModelPrePocessorSpec.scala deleted file mode 100644 index 369792e..0000000 --- a/modules/proto/tests/src/smithytranslate/proto3/internals/ModelPrePocessorSpec.scala +++ /dev/null @@ -1,415 +0,0 @@ -/* Copyright 2022 Disney Streaming - * - * Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://disneystreaming.github.io/TOST-1.0.txt - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package smithyproto.proto3 - -import munit._ -import software.amazon.smithy.model.Model -import software.amazon.smithy.model.shapes._ -import software.amazon.smithy.build.ProjectionTransformer -import software.amazon.smithy.build.TransformContext - -import scala.jdk.CollectionConverters._ -import scala.jdk.OptionConverters.RichOptional - -class ModelPrePocessorSpec extends FunSuite { - private def transitiveIsApplied( - allowedNamespace: Option[String] - )(implicit loc: Location) = { - val smithy = s"""|namespace test - | - |use alloy.proto#protoEnabled - | - |@protoEnabled - |structure Test { - | s: String - |} - | - |/// This one should not be converted - |structure Other { - | i: Integer - |} - |""".stripMargin - checkTransformer( - smithy, - ModelPreProcessor.transformers.Transitive(allowedNamespace) - ) { case (original, transformed) => - val removed = Set( - ShapeId.from("test#Other$i"), - ShapeId.from("test#Other") - ) - removed.foreach { sId => - assertEquals(original.getShape(sId).isPresent(), true) - assertEquals(transformed.getShape(sId).isPresent(), false) - } - val kept = Set( - ShapeId.from("test#Test$s"), - ShapeId.from("test#Test") - ) - kept.foreach { sId => - assertEquals(original.getShape(sId).isPresent(), true) - assertEquals(transformed.getShape(sId).isPresent(), true) - } - } - } - - test("apply Transitive on protoEnabled w/o an allowed namespace") { - transitiveIsApplied(None) - } - - test("apply Transitive on protoEnabled w/ an allowed namespace") { - transitiveIsApplied(Some("test")) - } - - test("apply Transitive does nothing if namespace is excluded") { - val smithy = s"""|namespace test - | - |use alloy.proto#protoEnabled - | - |@protoEnabled - |structure Test { - | s: String - |} - | - |/// This one should not be converted - |structure Other { - | i: Integer - |} - |""".stripMargin - checkTransformer( - smithy, - ModelPreProcessor.transformers.Transitive(Some("other_ns")) - ) { case (original, transformed) => - val removed = Set( - ShapeId.from("test#Test$s"), - ShapeId.from("test#Test"), - ShapeId.from("test#Other$i"), - ShapeId.from("test#Other") - ) - removed.foreach { sId => - assertEquals(original.getShape(sId).isPresent(), true) - assertEquals(transformed.getShape(sId).isPresent(), false) - } - } - } - - test("smithytranslate UUID is not included if alloy#UUID is not used") { - val smithy = s"""|namespace test - | - |structure Test { - | @required - | id: String - |} - |""".stripMargin - checkTransformer( - smithy, - ModelPreProcessor.transformers.CompactUUID - ) { case (original, transformed) => - assertEquals( - original.getShape(ShapeId.from("alloy#UUID")).isPresent(), - true - ) - assertEquals( - transformed.getShape(ShapeId.from("smithytranslate#UUID")).isPresent(), - false - ) - } - } - - test("alloy#UUID is converted to smithytranslate#UUID") { - val smithy = s"""|namespace test - | - |use alloy#UUID - | - |structure Test { - | @required - | id: UUID - |} - |""".stripMargin - checkTransformer( - smithy, - ModelPreProcessor.transformers.CompactUUID - ) { case (original, transformed) => - val removed = Set( - ShapeId.from("alloy#UUID") - ) - removed.foreach { sId => - assertEquals(original.getShape(sId).isPresent(), true) - assertEquals(transformed.getShape(sId).isPresent(), false) - } - assert( - transformed.getShape(ShapeId.from("smithytranslate#UUID")).isPresent() - ) - } - } - - def testPreludeReplacements( - name: String, - smithyShape: String, - kept: Set[ShapeId], - removed: Set[ShapeId] - ) = { - - /** Here, only BigInteger is used. We expect smithytranslate#BigInteger to - * be included, but not smithytranslate#BigDecimal or - * smithytranslate#Timestamp - */ - test( - s"PreludeReplacements - $name" - ) { - val smithy = s"""|namespace test - | - |use alloy.proto#protoEnabled - | - |@protoEnabled - |structure Test { - | s: $smithyShape - |} - |""".stripMargin - checkTransformer( - smithy, - ModelPreProcessor.transformers.Transitive(None) - ) { case (_, pruned) => - val resultModel = process( - pruned, - ModelPreProcessor.transformers.PreludeReplacements - ) - kept.foreach { sId => - assertEquals(pruned.getShape(sId).isPresent(), false) - assertEquals(resultModel.getShape(sId).isPresent(), true) - } - removed.foreach { sId => - assertEquals(pruned.getShape(sId).isPresent(), false) - assertEquals(resultModel.getShape(sId).isPresent(), false) - } - } - } - } - - testPreludeReplacements( - "keep big int", - "BigInteger", - Set( - ShapeId.from("smithytranslate#BigInteger") - ), - Set( - ShapeId.from("smithytranslate#BigDecimal"), - ShapeId.from("smithytranslate#Timestamp") - ) - ) - - testPreludeReplacements( - "keep timestamp", - "Timestamp", - Set( - ShapeId.from("smithytranslate#Timestamp") - ), - Set( - ShapeId.from("smithytranslate#BigDecimal"), - ShapeId.from("smithytranslate#BigInteger") - ) - ) - - testPreludeReplacements( - "keep big decimal", - "BigDecimal", - Set( - ShapeId.from("smithytranslate#BigDecimal") - ), - Set( - ShapeId.from("smithytranslate#Timestamp"), - ShapeId.from("smithytranslate#BigInteger") - ) - ) - - test("apply PreventEnumConflicts") { - val smithy = - """|$version: "2" - |namespace test - | - |enum Enum1 { - | VUNIQUE1 - | VCONFLICT - |} - | - |enum Enum2{ - | VUNIQUE2 - | VCONFLICT - |} - | - |enum Enum3{ - | VUNIQUE3 - | @enumValue("VCONFLICT") - | VCONFLICT3 - |} - | - |intEnum Enum4{ - | @enumValue(1) - | VUNIQUE4 - | @enumValue(2) - | VCONFLICT3 - |} - | - |enum NoConflict{ - | V1 - | V2 - | } - |""".stripMargin - checkTransformer( - smithy, - ModelPreProcessor.transformers.PreventEnumConflicts - ) { case (original, transformed) => - def getEnumNames(m: Model, shapeId: ShapeId): List[String] = { - m.getShape(shapeId) - .toScala - .toList - .collect { - case shape: EnumShape => - shape.getMemberNames.asScala.toList - case shape: IntEnumShape => - shape.getMemberNames.asScala.toList - } - .flatten - } - - assertEquals( - getEnumNames(original, ShapeId.from("test#Enum1")), - List("VUNIQUE1", "VCONFLICT") - ) - assertEquals( - getEnumNames(original, ShapeId.from("test#Enum2")), - List("VUNIQUE2", "VCONFLICT") - ) - assertEquals( - getEnumNames(original, ShapeId.from("test#Enum3")), - List("VUNIQUE3", "VCONFLICT3") - ) - - assertEquals( - getEnumNames(transformed, ShapeId.from("test#Enum1")), - List("VUNIQUE1", "ENUM1_VCONFLICT") - ) - assertEquals( - getEnumNames(transformed, ShapeId.from("test#Enum2")), - List("VUNIQUE2", "ENUM2_VCONFLICT") - ) - assertEquals( - getEnumNames(transformed, ShapeId.from("test#NoConflict")), - List("V1", "V2") - ) - assertEquals( - getEnumNames(transformed, ShapeId.from("test#Enum3")), - List("VUNIQUE3", "ENUM3_VCONFLICT3") - ) - assertEquals( - getEnumNames(transformed, ShapeId.from("test#Enum4")), - List("VUNIQUE4", "ENUM4_VCONFLICT3") - ) - } - } - - test("apply PreventEnumConflicts - across namespace") { - val smithyTest = - """|$version: "2" - |namespace test - | - |enum Enum1 { - | VCONFLICT - |} - | - |enum Enum2 { - | VCONFLICT - |} - |""".stripMargin - - val other = - """|$version: "2" - |namespace a.ns - | - |enum OtherEnum { - | VCONFLICT - |} - |""".stripMargin - val original = buildModel(smithyTest, other) - val transformed = - process(original, ModelPreProcessor.transformers.PreventEnumConflicts) - def getEnumNames(m: Model, shapeId: ShapeId): List[String] = { - m.getShape(shapeId) - .toScala - .toList - .collect { - case shape: EnumShape => - shape.getMemberNames.asScala.toList - case shape: IntEnumShape => - shape.getMemberNames.asScala.toList - } - .flatten - } - - assertEquals( - getEnumNames(original, ShapeId.from("test#Enum1")), - List("VCONFLICT") - ) - - assertEquals( - getEnumNames(transformed, ShapeId.from("test#Enum1")), - List("ENUM1_VCONFLICT") - ) - - assertEquals( - getEnumNames(original, ShapeId.from("test#Enum2")), - List("VCONFLICT") - ) - - assertEquals( - getEnumNames(transformed, ShapeId.from("test#Enum2")), - List("ENUM2_VCONFLICT") - ) - - assertEquals( - getEnumNames(original, ShapeId.from(s"a.ns#OtherEnum")), - List("VCONFLICT") - ) - - assertEquals( - getEnumNames(transformed, ShapeId.from(s"a.ns#OtherEnum")), - List("VCONFLICT") - ) - } - - private def checkTransformer(src: String, t: ProjectionTransformer)( - check: (Model, Model) => Unit - ): Unit = { - val original = buildModel(src) - val transformed = process(original, t) - check(original, transformed) - } - - private def buildModel(srcs: String*): Model = { - val assembler = Model - .assembler() - .discoverModels() - - srcs.zipWithIndex.foreach { case (s, i) => - assembler.addUnparsedModel(s"inlined-in-test.$i.smithy", s) - } - - assembler.assemble().unwrap() - } - - private def process(m: Model, t: ProjectionTransformer): Model = { - t.transform(TransformContext.builder().model(m).build()) - } -} diff --git a/modules/proto/tests/src/smithytranslate/proto3/internals/ProtoValidator.scala b/modules/proto/tests/src/smithytranslate/proto3/internals/ProtoValidator.scala index 0e7da9c..c600696 100644 --- a/modules/proto/tests/src/smithytranslate/proto3/internals/ProtoValidator.scala +++ b/modules/proto/tests/src/smithytranslate/proto3/internals/ProtoValidator.scala @@ -13,7 +13,7 @@ * limitations under the License. */ -package smithyproto.validation +package smithytranslate.proto3.internals import scalapb.compiler._ import protocgen.CodeGenRequest diff --git a/modules/proto/tests/src/smithytranslate/proto3/internals/ProtocInvocationHelper.scala b/modules/proto/tests/src/smithytranslate/proto3/internals/ProtocInvocationHelper.scala index 9dda0e2..eaec786 100644 --- a/modules/proto/tests/src/smithytranslate/proto3/internals/ProtocInvocationHelper.scala +++ b/modules/proto/tests/src/smithytranslate/proto3/internals/ProtocInvocationHelper.scala @@ -13,7 +13,7 @@ * limitations under the License. */ -package smithyproto.validation +package smithytranslate.proto3.internals import com.google.protobuf.Descriptors.FileDescriptor import java.nio.file.Files @@ -34,25 +34,24 @@ import scala.io.Source trait ProtocInvocationHelper { private lazy val protoc = ProtocRunner.forVersion(Version.protobufVersion) - private def loadProtoFiles(names: String*): List[(String, String)] = { - val dir = "/google/protobuf/" - val path = getClass.getResource(dir) - val folder = new File(path.getPath) - if (folder.exists && folder.isDirectory) { - folder.listFiles.toList - .collect { - case file if names.contains(file.getName) => - dir + file.getName -> Source - .fromFile(file) - .getLines() - .mkString("\n") - } - } else List.empty + private def loadProtoFiles(directories: String*): List[(String, String)] = { + directories.flatMap { d => + val dir = new File(getClass.getResource(d).getPath()) + dir.listFiles().toSeq.filter(_.getName().endsWith(".proto")).map { file => + (d + "/" + file.getName) -> Source + .fromFile(file) + .getLines() + .mkString("\n") + } + }.toList } def generateFileSet(files: Seq[(String, String)]): Seq[FileDescriptor] = { val tmpDir = Files.createTempDirectory("validation").toFile - val extraFiles = loadProtoFiles("wrappers.proto", "any.proto") + val extraFiles = loadProtoFiles( + "/google/protobuf/", + "/alloy/protobuf/" + ) val allFiles = files ++ extraFiles val fileNames = allFiles.map { case (name, content) => val names = name.split("/") diff --git a/modules/proto/tests/src/smithytranslate/proto3/internals/ProtocRunner.scala b/modules/proto/tests/src/smithytranslate/proto3/internals/ProtocRunner.scala index 32e2b65..6e962d9 100644 --- a/modules/proto/tests/src/smithytranslate/proto3/internals/ProtocRunner.scala +++ b/modules/proto/tests/src/smithytranslate/proto3/internals/ProtocRunner.scala @@ -13,7 +13,7 @@ * limitations under the License. */ -package smithyproto.validation +package smithytranslate.proto3.internals import coursier._ import coursier.core.Extension diff --git a/modules/proto/tests/src/smithytranslate/proto3/internals/RendererSuite.scala b/modules/proto/tests/src/smithytranslate/proto3/internals/RendererSuite.scala index d7f8489..ce3bc00 100644 --- a/modules/proto/tests/src/smithytranslate/proto3/internals/RendererSuite.scala +++ b/modules/proto/tests/src/smithytranslate/proto3/internals/RendererSuite.scala @@ -13,8 +13,7 @@ * limitations under the License. */ -package smithyproto -package proto3 +package smithytranslate.proto3.internals import munit.FunSuite @@ -59,14 +58,14 @@ class RendererSuite extends FunSuite { val result = Renderer.render(unit) val expected = s"""|syntax = "proto3"; - | - |package com.example; - | - |message Foo { - | int32 a = 1; - | repeated string b = 2; - |} - |""".stripMargin + | + |package com.example; + | + |message Foo { + | int32 a = 1; + | repeated string b = 2; + |} + |""".stripMargin assertEquals(result, expected) } @@ -93,11 +92,11 @@ class RendererSuite extends FunSuite { val result = Text.renderText(Renderer.renderMessage(node)) val expected = s"""|message Foo { - | reserved 3, 5 to 8; - | reserved "c", "d"; - | int32 a = 1; - | repeated string b = 2; - |}""".stripMargin + | reserved 3, 5 to 8; + | reserved "c", "d"; + | int32 a = 1; + | repeated string b = 2; + |}""".stripMargin assertEquals(result, expected) } @@ -119,10 +118,10 @@ class RendererSuite extends FunSuite { val result = Text.renderText(Renderer.renderEnum(node)) val expected = s"""|enum SomeEnum { - | reserved 3, 5 to 8; - | reserved "c", "d"; - | V1 = 1; - |}""".stripMargin + | reserved 3, 5 to 8; + | reserved "c", "d"; + | V1 = 1; + |}""".stripMargin assertEquals(result, expected) } @@ -167,16 +166,16 @@ class RendererSuite extends FunSuite { val result = Renderer.render(unit) val expected = s"""|syntax = "proto3"; - | - |package com.example; - | - |message Foo { - | oneof foo_oneof { - | int32 a = 1; - | repeated string b = 2 [deprecated = true]; - | } - |} - |""".stripMargin + | + |package com.example; + | + |message Foo { + | oneof foo_oneof { + | int32 a = 1; + | repeated string b = 2 [deprecated = true]; + | } + |} + |""".stripMargin assertEquals(result, expected) } @@ -231,24 +230,24 @@ class RendererSuite extends FunSuite { val result = Renderer.render(unit) val expected = s"""|syntax = "proto3"; - | - |package com.example; - | - |enum TopLevelEnum { - | reserved "c", "d"; - | FALSE = 0; - | TRUE = 1; - |} - | - |message Foo { - | enum Corpus { - | reserved 3, 5 to 8; - | UNIVERSAL = 0; - | WEB = 1; - | VIDEO = 2; - | } - |} - |""".stripMargin + | + |package com.example; + | + |enum TopLevelEnum { + | reserved "c", "d"; + | FALSE = 0; + | TRUE = 1; + |} + | + |message Foo { + | enum Corpus { + | reserved 3, 5 to 8; + | UNIVERSAL = 0; + | WEB = 1; + | VIDEO = 2; + | } + |} + |""".stripMargin assertEquals(result, expected) } diff --git a/modules/readme-validator/src/Validator.scala b/modules/readme-validator/src/Validator.scala index 7a9e622..c6184bc 100644 --- a/modules/readme-validator/src/Validator.scala +++ b/modules/readme-validator/src/Validator.scala @@ -17,12 +17,12 @@ import cats.data.NonEmptyList import java.nio.file.Path import scala.jdk.CollectionConverters._ import scala.util.control.NoStackTrace -import smithyproto.proto3.{Compiler => ProtoCompiler, ModelPreProcessor} import smithytranslate.compiler._ import smithytranslate.compiler.openapi._ import smithytranslate.compiler.json_schema._ import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.SmithyIdlModelSerializer +import smithytranslate.proto3.SmithyToProtoCompiler object Validator { @@ -191,10 +191,10 @@ object Validator { val ns = "foo" ns -> s"""|$$version: "2" - | - |namespace $ns - | - |$smithy""".stripMargin + | + |namespace $ns + | + |$smithy""".stripMargin } val getActualProto: String => String = { proto => val lines = proto.split("\n") @@ -209,7 +209,6 @@ object Validator { } } val ActualProto = List(getActualProto(proto)) - val compiler = new ProtoCompiler() val inputModel = Model .assembler() .discoverModels() @@ -217,22 +216,17 @@ object Validator { .assemble() .unwrap() - val result = compiler.compile( - ModelPreProcessor( - inputModel, - List(ModelPreProcessor.transformers.PreventEnumConflicts) - ) - ) - val rendered = result + val rendered = SmithyToProtoCompiler + .withConvertAllShapes(true) + .compile(inputModel) .filter(_.path.contains(namespace)) - .map(r => smithyproto.proto3.Renderer.render(r.unit)) + .map(_.contents) .sorted - rendered match { - case Nil => List(ValidationError.UnableToProduceOutput(actualSmithy)) - case ActualProto => Nil - case other => - List(ValidationError.ProtoConversionError(other, ActualProto)) - } + + if (rendered == ActualProto) Nil + else if (rendered.isEmpty) + List(ValidationError.UnableToProduceOutput(actualSmithy)) + else List(ValidationError.ProtoConversionError(rendered, ActualProto)) } def validate( diff --git a/modules/traits/src/smithytranslate/BigDecimal.java b/modules/traits/src/smithytranslate/BigDecimal.java deleted file mode 100644 index 0f587ae..0000000 --- a/modules/traits/src/smithytranslate/BigDecimal.java +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2022 Disney Streaming - * - * Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://disneystreaming.github.io/TOST-1.0.txt - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package smithytranslate; - -import software.amazon.smithy.model.shapes.*; -import software.amazon.smithy.model.loader.Prelude; -import software.amazon.smithy.model.traits.RequiredTrait; - -final public class BigDecimal { - static public ShapeId target = ShapeId.fromParts(Prelude.NAMESPACE, "String"); - static public Shape shape = - StructureShape - .builder() - .id("smithytranslate#BigDecimal") - .addMember( - MemberShape.builder() - .id("smithytranslate#BigDecimal$value") - .target(target) - .addTrait(new RequiredTrait()) - .build() - ) - .build(); -} diff --git a/modules/traits/src/smithytranslate/BigInteger.java b/modules/traits/src/smithytranslate/BigInteger.java deleted file mode 100644 index d845a60..0000000 --- a/modules/traits/src/smithytranslate/BigInteger.java +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2022 Disney Streaming - * - * Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://disneystreaming.github.io/TOST-1.0.txt - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package smithytranslate; - -import software.amazon.smithy.model.shapes.*; -import software.amazon.smithy.model.loader.Prelude; -import software.amazon.smithy.model.traits.RequiredTrait; - -final public class BigInteger { - static public ShapeId target = ShapeId.fromParts(Prelude.NAMESPACE, "String"); - static public Shape shape = - StructureShape - .builder() - .id("smithytranslate#BigInteger") - .addMember( - MemberShape.builder() - .id("smithytranslate#BigInteger$value") - .target(target) - .addTrait(new RequiredTrait()) - .build() - ) - .build(); -} diff --git a/modules/traits/src/smithytranslate/Timestamp.java b/modules/traits/src/smithytranslate/Timestamp.java deleted file mode 100644 index 9e9bbe8..0000000 --- a/modules/traits/src/smithytranslate/Timestamp.java +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2022 Disney Streaming - * - * Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://disneystreaming.github.io/TOST-1.0.txt - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package smithytranslate; - -import software.amazon.smithy.model.shapes.*; -import software.amazon.smithy.model.loader.Prelude; -import software.amazon.smithy.model.traits.RequiredTrait; - -final public class Timestamp { - static public ShapeId target = ShapeId.fromParts(Prelude.NAMESPACE, "Long"); - static public Shape shape = - StructureShape - .builder() - .id("smithytranslate#Timestamp") - .addMember( - MemberShape.builder() - .id("smithytranslate#Timestamp$value") - .target(target) - .addTrait(new RequiredTrait()) - .build() - ) - .build(); -} diff --git a/modules/traits/src/smithytranslate/UUID.java b/modules/traits/src/smithytranslate/UUID.java deleted file mode 100644 index f835eac..0000000 --- a/modules/traits/src/smithytranslate/UUID.java +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2022 Disney Streaming - * - * Licensed under the Tomorrow Open Source Technology License, Version 1.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://disneystreaming.github.io/TOST-1.0.txt - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package smithytranslate; - -import software.amazon.smithy.model.shapes.*; -import software.amazon.smithy.model.loader.Prelude; -import software.amazon.smithy.model.traits.RequiredTrait; - -final public class UUID { - static public Shape shape = StructureShape - .builder() - .id("smithytranslate#UUID") - .addMember( - "upper_bits", - ShapeId.fromParts("smithy.api", "Long"), - b -> b.addTrait(new RequiredTrait()) - ) - .addMember( - "lower_bits", - ShapeId.fromParts("smithy.api", "Long"), - b -> b.addTrait(new RequiredTrait()) - ) - .build(); -}