From a3a13c594cbd3554021b04813bda762834732a39 Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Wed, 12 Jun 2024 11:43:37 +0200 Subject: [PATCH] [protobuf] Add map support --- .../magnolify/protobuf/ProtobufType.scala | 66 ++++++++++++++----- protobuf/src/test/protobuf/Proto2.proto | 9 ++- protobuf/src/test/protobuf/Proto3.proto | 9 ++- .../protobuf/ProtobufTypeSuite.scala | 12 ++-- .../test/scala/magnolify/test/Simple.scala | 9 ++- 5 files changed, 80 insertions(+), 25 deletions(-) diff --git a/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala b/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala index e83572813..b59953833 100644 --- a/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala +++ b/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala @@ -19,7 +19,7 @@ package magnolify.protobuf import java.lang.reflect.Method import java.util as ju import com.google.protobuf.Descriptors.{Descriptor, EnumValueDescriptor, FieldDescriptor} -import com.google.protobuf.{ByteString, Message, ProtocolMessageEnum} +import com.google.protobuf.{ByteString, MapEntry, Message, ProtocolMessageEnum} import magnolia1.* import magnolify.shared.* import magnolify.shims.FactoryCompat @@ -54,17 +54,13 @@ object ProtobufType { r.checkDefaults(descriptor)(cm) } - @transient private var _newBuilder: Method = _ - private def newBuilder: Message.Builder = { - if (_newBuilder == null) { - _newBuilder = ct.runtimeClass.getMethod("newBuilder") - } + @transient private lazy val _newBuilder: Method = ct.runtimeClass.getMethod("newBuilder") + private def newBuilder(): Message.Builder = _newBuilder.invoke(null).asInstanceOf[Message.Builder] - } private val caseMapper: CaseMapper = cm override def from(v: MsgT): T = r.from(v)(caseMapper) - override def to(v: T): MsgT = r.to(v, newBuilder)(caseMapper).asInstanceOf[MsgT] + override def to(v: T): MsgT = r.to(v, newBuilder())(caseMapper).asInstanceOf[MsgT] } case _ => throw new IllegalArgumentException(s"ProtobufType can only be created from Record. Got $f") @@ -130,6 +126,10 @@ object ProtobufField { } ) + private def newFieldBuilder(b: Message.Builder)(f: FieldDescriptor): Message.Builder = + if (f.getType != FieldDescriptor.Type.MESSAGE) null + else b.newBuilderForField(f) + override def checkDefaults(descriptor: Descriptor)(cm: CaseMapper): Unit = { val fields = getFields(descriptor)(cm) caseClass.parameters.foreach { p => @@ -169,17 +169,11 @@ object ProtobufField { override def to(v: T, bu: Message.Builder)(cm: CaseMapper): Message = { val fields = getFields(bu.getDescriptorForType)(cm) - caseClass.parameters .foldLeft(bu) { (b, p) => val field = fields(p.index) - val value = if (field.getType == FieldDescriptor.Type.MESSAGE) { - // nested records - p.typeclass.to(p.dereference(v), b.newBuilderForField(field))(cm) - } else { - // non-nested - p.typeclass.to(p.dereference(v), null)(cm) - } + val builder = newFieldBuilder(bu)(field) + val value = p.typeclass.to(p.dereference(v), builder)(cm) if (value == null) b else b.setField(field, value) } .build() @@ -284,4 +278,44 @@ object ProtobufField { override def to(v: C[T], b: Message.Builder)(cm: CaseMapper): ju.List[f.ToT] = if (v.isEmpty) null else v.iterator.map(f.to(_, b)(cm)).toList.asJava } + + implicit def pfMap[K, V](implicit + kf: ProtobufField[K], + vf: ProtobufField[V] + ): ProtobufField[Map[K, V]] = + new Aux[Map[K, V], ju.List[MapEntry[kf.FromT, vf.FromT]], ju.List[MapEntry[kf.ToT, vf.ToT]]] { + + override val default: Option[Map[K, V]] = Some(Map.empty) + + override def from(v: ju.List[MapEntry[kf.FromT, vf.FromT]])(cm: CaseMapper): Map[K, V] = { + val b = Map.newBuilder[K, V] + if (v != null) { + b ++= v.asScala.map(me => kf.from(me.getKey)(cm) -> vf.from(me.getValue)(cm)) + } + b.result() + } + + private def newFieldBuilder(b: Message.Builder)(f: FieldDescriptor): Message.Builder = + if (f.getType != FieldDescriptor.Type.MESSAGE) null + else b.newBuilderForField(f) + + override def to(v: Map[K, V], b: Message.Builder)( + cm: CaseMapper + ): ju.List[MapEntry[kf.ToT, vf.ToT]] = { + if (v.isEmpty) { + null + } else { + val keyField = b.getDescriptorForType.findFieldByName("key") + val valueField = b.getDescriptorForType.findFieldByName("value") + v.map { case (k, v) => + b + .setField(keyField, kf.to(k, newFieldBuilder(b)(keyField))(cm)) + .setField(valueField, vf.to(v, newFieldBuilder(b)(valueField))(cm)) + .build() + .asInstanceOf[MapEntry[kf.ToT, vf.ToT]] + }.toList + .asJava + } + } + } } diff --git a/protobuf/src/test/protobuf/Proto2.proto b/protobuf/src/test/protobuf/Proto2.proto index af4f12d1e..ab41afa72 100644 --- a/protobuf/src/test/protobuf/Proto2.proto +++ b/protobuf/src/test/protobuf/Proto2.proto @@ -41,19 +41,24 @@ message NestedP2 { repeated RequiredP2 l = 6; } -message CollectionP2 { +message CollectionsP2 { repeated int32 a = 1; repeated int32 l = 2; repeated int32 v = 3; repeated int32 s = 4; } -message MoreCollectionP2 { +message MoreCollectionsP2 { repeated int32 i = 1; repeated int32 s = 2; repeated int32 is = 3; } +message MapsP2 { + map mp = 1; + map mn = 2; +} + message EnumsP2 { enum JavaEnums { RED = 0; diff --git a/protobuf/src/test/protobuf/Proto3.proto b/protobuf/src/test/protobuf/Proto3.proto index 1fe02fc6f..e9a9c23ff 100644 --- a/protobuf/src/test/protobuf/Proto3.proto +++ b/protobuf/src/test/protobuf/Proto3.proto @@ -41,19 +41,24 @@ message NestedP3 { repeated RequiredP3 l = 6; } -message CollectionP3 { +message CollectionsP3 { repeated int32 a = 1; repeated int32 l = 2; repeated int32 v = 3; repeated int32 s = 4; } -message MoreCollectionP3 { +message MoreCollectionsP3 { repeated int32 i = 1; repeated int32 s = 2; repeated int32 is = 3; } +message MapsP3 { + map mp = 1; + map mn = 2; +} + message EnumsP3 { enum JavaEnums { RED = 0; diff --git a/protobuf/src/test/scala/magnolify/protobuf/ProtobufTypeSuite.scala b/protobuf/src/test/scala/magnolify/protobuf/ProtobufTypeSuite.scala index c6cb11637..c7bfd7423 100644 --- a/protobuf/src/test/scala/magnolify/protobuf/ProtobufTypeSuite.scala +++ b/protobuf/src/test/scala/magnolify/protobuf/ProtobufTypeSuite.scala @@ -72,10 +72,13 @@ class ProtobufTypeSuite extends BaseProtobufTypeSuite { test[UnsafeChar, IntegersP2] test[UnsafeShort, IntegersP2] - test[Collections, CollectionP2] - test[Collections, CollectionP3] - test[MoreCollections, MoreCollectionP2] - test[MoreCollections, MoreCollectionP3] + test[Collections, CollectionsP2] + test[Collections, CollectionsP3] + test[MoreCollections, MoreCollectionsP2] + test[MoreCollections, MoreCollectionsP3] + + test[Maps, MapsP2] + test[Maps, MapsP3] test("AnyVal") { test[ProtoHasValueClass, IntegersP2] @@ -165,6 +168,7 @@ object Proto3Enums { ProtobufField.enum[ADT.Color, EnumsP3.ScalaEnums] } +case class Maps(mp: Map[String, Int], mn: Map[String, Nested]) case class ProtoValueClass(value: Long) extends AnyVal case class ProtoHasValueClass(i: Int, l: ProtoValueClass) case class UnsafeByte(i: Byte, l: Long) diff --git a/test/src/test/scala/magnolify/test/Simple.scala b/test/src/test/scala/magnolify/test/Simple.scala index b9b03eeef..a08902ecb 100644 --- a/test/src/test/scala/magnolify/test/Simple.scala +++ b/test/src/test/scala/magnolify/test/Simple.scala @@ -51,6 +51,7 @@ object Simple { o: Option[Required], l: List[Required] ) + case class Collections(a: Array[Int], l: List[Int], v: Vector[Int], s: Set[Int]) { override def hashCode(): Int = { @@ -71,7 +72,12 @@ object Simple { case _ => false } } - case class MoreCollections(i: Iterable[Int], s: Seq[Int], is: IndexedSeq[Int]) + case class MoreCollections( + i: Iterable[Int], + s: Seq[Int], + is: IndexedSeq[Int] + ) + case class Enums( j: JavaEnums.Color, s: ScalaEnums.Color.Type, @@ -94,6 +100,7 @@ object Simple { sr: List[UnsafeEnum[ScalaEnums.Color.Type]], ar: List[UnsafeEnum[ADT.Color]] ) + case class Custom(u: URI, d: Duration) case class LowerCamel(firstField: String, secondField: String, innerField: LowerCamelInner)