From 5cbdfe28303add11b0b665811216b4542112310c Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Tue, 23 Jul 2024 09:30:07 +0200 Subject: [PATCH] [protobuf] Add map support (#988) --- docs/mapping.md | 58 ++++++++-------- .../magnolify/protobuf/ProtobufType.scala | 66 ++++++++++++++----- protobuf/src/test/protobuf/Proto2.proto | 9 ++- protobuf/src/test/protobuf/Proto3.proto | 9 ++- .../protobuf/ProtobufTypeSuite.scala | 12 ++-- .../test/scala/magnolify/test/Simple.scala | 9 ++- 6 files changed, 109 insertions(+), 54 deletions(-) diff --git a/docs/mapping.md b/docs/mapping.md index 91fb93ba0..bafde4b65 100644 --- a/docs/mapping.md +++ b/docs/mapping.md @@ -2,38 +2,37 @@ | Scala | Avro | BigQuery | Bigtable7 | Datastore | Parquet | Protobuf | TensorFlow | |-----------------------------------|------------------------------|------------------------|---------------------------------|-----------------------|-----------------------------------|-------------------------|---------------------| -| `Unit` | `NULL` | x | x | `Null` | x | x | x | -| `Boolean` | `BOOLEAN` | `BOOL` | `Byte` | `Boolean` | `BOOLEAN` | `Boolean` | `INT64`3 | -| `Char` | `INT`3 | `INT64`3 | `Char` | `Integer`3 | `INT32`3 | `Int`3 | `INT64`3 | -| `Byte` | `INT`3 | `INT64`3 | `Byte` | `Integer`3 | `INT32`9 | `Int`3 | `INT64`3 | -| `Short` | `INT`3 | `INT64`3 | `Short` | `Integer`3 | `INT32`9 | `Int`3 | `INT64`3 | -| `Int` | `INT` | `INT64`3 | `Int` | `Integer`3 | `INT32`9 | `Int` | `INT64`3 | -| `Long` | `LONG` | `INT64` | `Long` | `Integer` | `INT64`9 | `Long` | `INT64` | -| `Float` | `FLOAT` | `FLOAT64`3 | `Float` | `Double`3 | `FLOAT` | `Float` | `FLOAT` | -| `Double` | `DOUBLE` | `FLOAT64` | `Double` | `Double` | `DOUBLE` | `Double` | `FLOAT`3 | -| `CharSequence` | `STRING` | x | x | x | x | x | x | -| `String` | `STRING` | `STRING` | `String` | `String` | `BINARY` | `String` | `BYTES`3 | -| `Array[Byte]` | `BYTES` | `BYTES` | `ByteString` | `Blob` | `BINARY` | `ByteString` | `BYTES` | +| `Unit` | 
`null` | x | x | `Null` | x | x | x | +| `Boolean` | `boolean` | `BOOL` | `Byte` | `Boolean` | `BOOLEAN` | `Boolean` | `INT64`3 | +| `Char` | `int`3 | `INT64`3 | `Char` | `Integer`3 | `INT32`3 | `Int`3 | `INT64`3 | +| `Byte` | `int`3 | `INT64`3 | `Byte` | `Integer`3 | `INT32`9 | `Int`3 | `INT64`3 | +| `Short` | `int`3 | `INT64`3 | `Short` | `Integer`3 | `INT32`9 | `Int`3 | `INT64`3 | +| `Int` | `int` | `INT64`3 | `Int` | `Integer`3 | `INT32`9 | `Int` | `INT64`3 | +| `Long` | `long` | `INT64` | `Long` | `Integer` | `INT64`9 | `Long` | `INT64` | +| `Float` | `float` | `FLOAT64`3 | `Float` | `Double`3 | `FLOAT` | `Float` | `FLOAT` | +| `Double` | `double` | `FLOAT64` | `Double` | `Double` | `DOUBLE` | `Double` | `FLOAT`3 | +| `CharSequence` | `string` | x | x | x | x | x | x | +| `String` | `string` | `STRING` | `String` | `String` | `BINARY` | `String` | `BYTES`3 | +| `Array[Byte]` | `bytes` | `BYTES` | `ByteString` | `Blob` | `BINARY` | `ByteString` | `BYTES` | | `ByteString` | x | x | `ByteString` | `Blob` | x | `ByteString` | `BYTES` | -| `ByteBuffer` | `BYTES` | x | x | | x | x | x | -| Enum1 | `ENUM` | `STRING`3 | `String` | `String`3 | `BINARY`/`ENUM`9 | Enum | `BYTES`3 | +| `ByteBuffer` | `bytes` | x | x | | x | x | x | +| Enum1 | `enum` | `STRING`3 | `String` | `String`3 | `BINARY`/`ENUM`9 | Enum | `BYTES`3 | | `BigInt` | x | x | `BigInt` | x | x | x | x | -| `BigDecimal` | `BYTES`4 | `NUMERIC`6 | `Int` scale + unscaled `BigInt` | x | `LOGICAL[DECIMAL]`9,14 | x | x | -| `Option[T]` | `UNION[NULL, T]`5 | `NULLABLE` | Empty as `None` | Absent as `None` | `OPTIONAL` | `optional`10 | Size <= 1 | -| `Iterable[T]`2 | `ARRAY` | `REPEATED` | x | `Array` | `REPEATED`13 | `repeated` | Size >= 0 | -| Nested | `RECORD` | `STRUCT` | Flat8 | `Entity` | Group | `Message` | Flat8 | -| `Map[CharSequence, T]` | `MAP[STRING, T]` | x | x | x | x | x | | -| `Map[String, T]` | `MAP[STRING, T]` | x | x | x | x | x | x | -| `java.time.Instant` | `LONG`11 | `TIMESTAMP` | x | 
`Timestamp` | `LOGICAL[TIMESTAMP]`9 | x | x | -| `java.time.LocalDateTime` | `LONG`11 | `DATETIME` | x | x | `LOGICAL[TIMESTAMP]`9 | x | x | +| `BigDecimal` | `bytes`4 | `NUMERIC`6 | `Int` scale + unscaled `BigInt` | x | `LOGICAL[DECIMAL]`9,14 | x | x | +| `Option[T]` | `union[null, T]`5 | `NULLABLE` | Empty as `None` | Absent as `None` | `OPTIONAL` | `optional`10 | Size <= 1 | +| `Iterable[T]`2 | `array[T]` | `REPEATED` | x | `Array` | `REPEATED`13 | `repeated` | Size >= 0 | +| Nested | `record` | `STRUCT` | Flat8 | `Entity` | Group | `Message` | Flat8 | +| `Map[K, V]` | `map[V]`15 | x | x | x | x | `map` | x | +| `java.time.Instant` | `long`11 | `TIMESTAMP` | x | `Timestamp` | `LOGICAL[TIMESTAMP]`9 | x | x | +| `java.time.LocalDateTime` | `long`11 | `DATETIME` | x | x | `LOGICAL[TIMESTAMP]`9 | x | x | | `java.time.OffsetTime` | x | x | x | x | `LOGICAL[TIME]`9 | x | x | -| `java.time.LocalTime` | `LONG`11 | `TIME` | x | x | `LOGICAL[TIME]`9 | x | x | -| `java.time.LocalDate` | `INT`11 | `DATE` | x | x | `LOGICAL[DATE]`9 | x | x | -| `org.joda.time.LocalDate` | `INT`11 | x | x | x | x | x | x | -| `org.joda.time.DateTime` | `INT`11 | x | x | x | x | x | x | -| `org.joda.time.LocalTime` | `INT`11 | x | x | x | x | x | x | -| `java.util.UUID` | `STRING`4 | x | ByteString (16 bytes) | x | `FIXED[16]` | x | x | -| `(Long, Long, Long)`12 | `FIXED[12]` | x | x | x | x | x | x | +| `java.time.LocalTime` | `long`11 | `TIME` | x | x | `LOGICAL[TIME]`9 | x | x | +| `java.time.LocalDate` | `int`11 | `DATE` | x | x | `LOGICAL[DATE]`9 | x | x | +| `org.joda.time.LocalDate` | `int`11 | x | x | x | x | x | x | +| `org.joda.time.DateTime` | `int`11 | x | x | x | x | x | x | +| `org.joda.time.LocalTime` | `int`11 | x | x | x | x | x | x | +| `java.util.UUID` | `string`4 | x | ByteString (16 bytes) | x | `FIXED[16]` | x | x | +| `(Long, Long, Long)`12 | `fixed[12]` | x | x | x | x | x | x | 1. 
Those wrapped in`UnsafeEnum` are encoded as strings, see [enums.md](https://github.com/spotify/magnolify/blob/master/docs/enums.md) for more @@ -59,3 +58,4 @@ format: `required group $FIELDNAME (LIST) { repeated $FIELDTYPE array ($FIELDSCHEMA); }`. 14. Parquet's Decimal logical format supports multiple representations, and are not implicitly scoped by default. Import one of: `magnolify.parquet.ParquetField.{decimal32, decimal64, decimalFixed, decimalBinary}`. +15. Map key type in avro is fixed to string. Scala Map key type must be either `String` or `CharSequence`. \ No newline at end of file diff --git a/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala b/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala index e83572813..b59953833 100644 --- a/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala +++ b/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala @@ -19,7 +19,7 @@ package magnolify.protobuf import java.lang.reflect.Method import java.util as ju import com.google.protobuf.Descriptors.{Descriptor, EnumValueDescriptor, FieldDescriptor} -import com.google.protobuf.{ByteString, Message, ProtocolMessageEnum} +import com.google.protobuf.{ByteString, MapEntry, Message, ProtocolMessageEnum} import magnolia1.* import magnolify.shared.* import magnolify.shims.FactoryCompat @@ -54,17 +54,13 @@ object ProtobufType { r.checkDefaults(descriptor)(cm) } - @transient private var _newBuilder: Method = _ - private def newBuilder: Message.Builder = { - if (_newBuilder == null) { - _newBuilder = ct.runtimeClass.getMethod("newBuilder") - } + @transient private lazy val _newBuilder: Method = ct.runtimeClass.getMethod("newBuilder") + private def newBuilder(): Message.Builder = _newBuilder.invoke(null).asInstanceOf[Message.Builder] - } private val caseMapper: CaseMapper = cm override def from(v: MsgT): T = r.from(v)(caseMapper) - override def to(v: T): MsgT = r.to(v, newBuilder)(caseMapper).asInstanceOf[MsgT] + override def to(v: T): MsgT = 
r.to(v, newBuilder())(caseMapper).asInstanceOf[MsgT] } case _ => throw new IllegalArgumentException(s"ProtobufType can only be created from Record. Got $f") @@ -130,6 +126,10 @@ object ProtobufField { } ) + private def newFieldBuilder(b: Message.Builder)(f: FieldDescriptor): Message.Builder = + if (f.getType != FieldDescriptor.Type.MESSAGE) null + else b.newBuilderForField(f) + override def checkDefaults(descriptor: Descriptor)(cm: CaseMapper): Unit = { val fields = getFields(descriptor)(cm) caseClass.parameters.foreach { p => @@ -169,17 +169,11 @@ object ProtobufField { override def to(v: T, bu: Message.Builder)(cm: CaseMapper): Message = { val fields = getFields(bu.getDescriptorForType)(cm) - caseClass.parameters .foldLeft(bu) { (b, p) => val field = fields(p.index) - val value = if (field.getType == FieldDescriptor.Type.MESSAGE) { - // nested records - p.typeclass.to(p.dereference(v), b.newBuilderForField(field))(cm) - } else { - // non-nested - p.typeclass.to(p.dereference(v), null)(cm) - } + val builder = newFieldBuilder(bu)(field) + val value = p.typeclass.to(p.dereference(v), builder)(cm) if (value == null) b else b.setField(field, value) } .build() @@ -284,4 +278,44 @@ object ProtobufField { override def to(v: C[T], b: Message.Builder)(cm: CaseMapper): ju.List[f.ToT] = if (v.isEmpty) null else v.iterator.map(f.to(_, b)(cm)).toList.asJava } + + implicit def pfMap[K, V](implicit + kf: ProtobufField[K], + vf: ProtobufField[V] + ): ProtobufField[Map[K, V]] = + new Aux[Map[K, V], ju.List[MapEntry[kf.FromT, vf.FromT]], ju.List[MapEntry[kf.ToT, vf.ToT]]] { + + override val default: Option[Map[K, V]] = Some(Map.empty) + + override def from(v: ju.List[MapEntry[kf.FromT, vf.FromT]])(cm: CaseMapper): Map[K, V] = { + val b = Map.newBuilder[K, V] + if (v != null) { + b ++= v.asScala.map(me => kf.from(me.getKey)(cm) -> vf.from(me.getValue)(cm)) + } + b.result() + } + + private def newFieldBuilder(b: Message.Builder)(f: FieldDescriptor): Message.Builder = + if (f.getType 
!= FieldDescriptor.Type.MESSAGE) null + else b.newBuilderForField(f) + + override def to(v: Map[K, V], b: Message.Builder)( + cm: CaseMapper + ): ju.List[MapEntry[kf.ToT, vf.ToT]] = { + if (v.isEmpty) { + null + } else { + val keyField = b.getDescriptorForType.findFieldByName("key") + val valueField = b.getDescriptorForType.findFieldByName("value") + v.map { case (k, v) => + b + .setField(keyField, kf.to(k, newFieldBuilder(b)(keyField))(cm)) + .setField(valueField, vf.to(v, newFieldBuilder(b)(valueField))(cm)) + .build() + .asInstanceOf[MapEntry[kf.ToT, vf.ToT]] + }.toList + .asJava + } + } + } } diff --git a/protobuf/src/test/protobuf/Proto2.proto b/protobuf/src/test/protobuf/Proto2.proto index af4f12d1e..ab41afa72 100644 --- a/protobuf/src/test/protobuf/Proto2.proto +++ b/protobuf/src/test/protobuf/Proto2.proto @@ -41,19 +41,24 @@ message NestedP2 { repeated RequiredP2 l = 6; } -message CollectionP2 { +message CollectionsP2 { repeated int32 a = 1; repeated int32 l = 2; repeated int32 v = 3; repeated int32 s = 4; } -message MoreCollectionP2 { +message MoreCollectionsP2 { repeated int32 i = 1; repeated int32 s = 2; repeated int32 is = 3; } +message MapsP2 { + map<string, int32> mp = 1; + map<string, NestedP2> mn = 2; +} + message EnumsP2 { enum JavaEnums { RED = 0; diff --git a/protobuf/src/test/protobuf/Proto3.proto b/protobuf/src/test/protobuf/Proto3.proto index 1fe02fc6f..e9a9c23ff 100644 --- a/protobuf/src/test/protobuf/Proto3.proto +++ b/protobuf/src/test/protobuf/Proto3.proto @@ -41,19 +41,24 @@ message NestedP3 { repeated RequiredP3 l = 6; } -message CollectionP3 { +message CollectionsP3 { repeated int32 a = 1; repeated int32 l = 2; repeated int32 v = 3; repeated int32 s = 4; } -message MoreCollectionP3 { +message MoreCollectionsP3 { repeated int32 i = 1; repeated int32 s = 2; repeated int32 is = 3; } +message MapsP3 { + map<string, int32> mp = 1; + map<string, NestedP3> mn = 2; +} + message EnumsP3 { enum JavaEnums { RED = 0; diff --git a/protobuf/src/test/scala/magnolify/protobuf/ProtobufTypeSuite.scala
b/protobuf/src/test/scala/magnolify/protobuf/ProtobufTypeSuite.scala index c6cb11637..c7bfd7423 100644 --- a/protobuf/src/test/scala/magnolify/protobuf/ProtobufTypeSuite.scala +++ b/protobuf/src/test/scala/magnolify/protobuf/ProtobufTypeSuite.scala @@ -72,10 +72,13 @@ class ProtobufTypeSuite extends BaseProtobufTypeSuite { test[UnsafeChar, IntegersP2] test[UnsafeShort, IntegersP2] - test[Collections, CollectionP2] - test[Collections, CollectionP3] - test[MoreCollections, MoreCollectionP2] - test[MoreCollections, MoreCollectionP3] + test[Collections, CollectionsP2] + test[Collections, CollectionsP3] + test[MoreCollections, MoreCollectionsP2] + test[MoreCollections, MoreCollectionsP3] + + test[Maps, MapsP2] + test[Maps, MapsP3] test("AnyVal") { test[ProtoHasValueClass, IntegersP2] @@ -165,6 +168,7 @@ object Proto3Enums { ProtobufField.enum[ADT.Color, EnumsP3.ScalaEnums] } +case class Maps(mp: Map[String, Int], mn: Map[String, Nested]) case class ProtoValueClass(value: Long) extends AnyVal case class ProtoHasValueClass(i: Int, l: ProtoValueClass) case class UnsafeByte(i: Byte, l: Long) diff --git a/test/src/test/scala/magnolify/test/Simple.scala b/test/src/test/scala/magnolify/test/Simple.scala index b9b03eeef..a08902ecb 100644 --- a/test/src/test/scala/magnolify/test/Simple.scala +++ b/test/src/test/scala/magnolify/test/Simple.scala @@ -51,6 +51,7 @@ object Simple { o: Option[Required], l: List[Required] ) + case class Collections(a: Array[Int], l: List[Int], v: Vector[Int], s: Set[Int]) { override def hashCode(): Int = { @@ -71,7 +72,12 @@ object Simple { case _ => false } } - case class MoreCollections(i: Iterable[Int], s: Seq[Int], is: IndexedSeq[Int]) + case class MoreCollections( + i: Iterable[Int], + s: Seq[Int], + is: IndexedSeq[Int] + ) + case class Enums( j: JavaEnums.Color, s: ScalaEnums.Color.Type, @@ -94,6 +100,7 @@ object Simple { sr: List[UnsafeEnum[ScalaEnums.Color.Type]], ar: List[UnsafeEnum[ADT.Color]] ) + case class Custom(u: URI, d: Duration) case 
class LowerCamel(firstField: String, secondField: String, innerField: LowerCamelInner)