Skip to content

Commit

Permalink
[protobuf] Add map support
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones committed Jun 12, 2024
1 parent d6082e1 commit a3a13c5
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 25 deletions.
66 changes: 50 additions & 16 deletions protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package magnolify.protobuf
import java.lang.reflect.Method
import java.util as ju
import com.google.protobuf.Descriptors.{Descriptor, EnumValueDescriptor, FieldDescriptor}
import com.google.protobuf.{ByteString, Message, ProtocolMessageEnum}
import com.google.protobuf.{ByteString, MapEntry, Message, ProtocolMessageEnum}
import magnolia1.*
import magnolify.shared.*
import magnolify.shims.FactoryCompat
Expand Down Expand Up @@ -54,17 +54,13 @@ object ProtobufType {
r.checkDefaults(descriptor)(cm)
}

@transient private var _newBuilder: Method = _
private def newBuilder: Message.Builder = {
if (_newBuilder == null) {
_newBuilder = ct.runtimeClass.getMethod("newBuilder")
}
@transient private lazy val _newBuilder: Method = ct.runtimeClass.getMethod("newBuilder")
private def newBuilder(): Message.Builder =
_newBuilder.invoke(null).asInstanceOf[Message.Builder]
}

private val caseMapper: CaseMapper = cm
override def from(v: MsgT): T = r.from(v)(caseMapper)
override def to(v: T): MsgT = r.to(v, newBuilder)(caseMapper).asInstanceOf[MsgT]
override def to(v: T): MsgT = r.to(v, newBuilder())(caseMapper).asInstanceOf[MsgT]
}
case _ =>
throw new IllegalArgumentException(s"ProtobufType can only be created from Record. Got $f")
Expand Down Expand Up @@ -130,6 +126,10 @@ object ProtobufField {
}
)

private def newFieldBuilder(b: Message.Builder)(f: FieldDescriptor): Message.Builder =
if (f.getType != FieldDescriptor.Type.MESSAGE) null
else b.newBuilderForField(f)

override def checkDefaults(descriptor: Descriptor)(cm: CaseMapper): Unit = {
val fields = getFields(descriptor)(cm)
caseClass.parameters.foreach { p =>
Expand Down Expand Up @@ -169,17 +169,11 @@ object ProtobufField {

override def to(v: T, bu: Message.Builder)(cm: CaseMapper): Message = {
val fields = getFields(bu.getDescriptorForType)(cm)

caseClass.parameters
.foldLeft(bu) { (b, p) =>
val field = fields(p.index)
val value = if (field.getType == FieldDescriptor.Type.MESSAGE) {
// nested records
p.typeclass.to(p.dereference(v), b.newBuilderForField(field))(cm)
} else {
// non-nested
p.typeclass.to(p.dereference(v), null)(cm)
}
val builder = newFieldBuilder(bu)(field)
val value = p.typeclass.to(p.dereference(v), builder)(cm)
if (value == null) b else b.setField(field, value)
}
.build()
Expand Down Expand Up @@ -284,4 +278,44 @@ object ProtobufField {
override def to(v: C[T], b: Message.Builder)(cm: CaseMapper): ju.List[f.ToT] =
if (v.isEmpty) null else v.iterator.map(f.to(_, b)(cm)).toList.asJava
}

implicit def pfMap[K, V](implicit
kf: ProtobufField[K],
vf: ProtobufField[V]
): ProtobufField[Map[K, V]] =
new Aux[Map[K, V], ju.List[MapEntry[kf.FromT, vf.FromT]], ju.List[MapEntry[kf.ToT, vf.ToT]]] {

override val default: Option[Map[K, V]] = Some(Map.empty)

override def from(v: ju.List[MapEntry[kf.FromT, vf.FromT]])(cm: CaseMapper): Map[K, V] = {
val b = Map.newBuilder[K, V]
if (v != null) {
b ++= v.asScala.map(me => kf.from(me.getKey)(cm) -> vf.from(me.getValue)(cm))
}
b.result()
}

private def newFieldBuilder(b: Message.Builder)(f: FieldDescriptor): Message.Builder =
if (f.getType != FieldDescriptor.Type.MESSAGE) null
else b.newBuilderForField(f)

override def to(v: Map[K, V], b: Message.Builder)(
cm: CaseMapper
): ju.List[MapEntry[kf.ToT, vf.ToT]] = {
if (v.isEmpty) {
null
} else {
val keyField = b.getDescriptorForType.findFieldByName("key")
val valueField = b.getDescriptorForType.findFieldByName("value")
v.map { case (k, v) =>
b
.setField(keyField, kf.to(k, newFieldBuilder(b)(keyField))(cm))
.setField(valueField, vf.to(v, newFieldBuilder(b)(valueField))(cm))
.build()
.asInstanceOf[MapEntry[kf.ToT, vf.ToT]]
}.toList
.asJava
}
}
}
}
9 changes: 7 additions & 2 deletions protobuf/src/test/protobuf/Proto2.proto
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,24 @@ message NestedP2 {
repeated RequiredP2 l = 6;
}

message CollectionP2 {
message CollectionsP2 {
repeated int32 a = 1;
repeated int32 l = 2;
repeated int32 v = 3;
repeated int32 s = 4;
}

message MoreCollectionP2 {
message MoreCollectionsP2 {
repeated int32 i = 1;
repeated int32 s = 2;
repeated int32 is = 3;
}

message MapsP2 {
map<string, int32> mp = 1;
map<string, NestedP2> mn = 2;
}

message EnumsP2 {
enum JavaEnums {
RED = 0;
Expand Down
9 changes: 7 additions & 2 deletions protobuf/src/test/protobuf/Proto3.proto
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,24 @@ message NestedP3 {
repeated RequiredP3 l = 6;
}

message CollectionP3 {
message CollectionsP3 {
repeated int32 a = 1;
repeated int32 l = 2;
repeated int32 v = 3;
repeated int32 s = 4;
}

message MoreCollectionP3 {
message MoreCollectionsP3 {
repeated int32 i = 1;
repeated int32 s = 2;
repeated int32 is = 3;
}

message MapsP3 {
map<string, int32> mp = 1;
map<string, NestedP3> mn = 2;
}

message EnumsP3 {
enum JavaEnums {
RED = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,13 @@ class ProtobufTypeSuite extends BaseProtobufTypeSuite {
test[UnsafeChar, IntegersP2]
test[UnsafeShort, IntegersP2]

test[Collections, CollectionP2]
test[Collections, CollectionP3]
test[MoreCollections, MoreCollectionP2]
test[MoreCollections, MoreCollectionP3]
test[Collections, CollectionsP2]
test[Collections, CollectionsP3]
test[MoreCollections, MoreCollectionsP2]
test[MoreCollections, MoreCollectionsP3]

test[Maps, MapsP2]
test[Maps, MapsP3]

test("AnyVal") {
test[ProtoHasValueClass, IntegersP2]
Expand Down Expand Up @@ -165,6 +168,7 @@ object Proto3Enums {
ProtobufField.enum[ADT.Color, EnumsP3.ScalaEnums]
}

case class Maps(mp: Map[String, Int], mn: Map[String, Nested])
case class ProtoValueClass(value: Long) extends AnyVal
case class ProtoHasValueClass(i: Int, l: ProtoValueClass)
case class UnsafeByte(i: Byte, l: Long)
Expand Down
9 changes: 8 additions & 1 deletion test/src/test/scala/magnolify/test/Simple.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ object Simple {
o: Option[Required],
l: List[Required]
)

case class Collections(a: Array[Int], l: List[Int], v: Vector[Int], s: Set[Int]) {

override def hashCode(): Int = {
Expand All @@ -71,7 +72,12 @@ object Simple {
case _ => false
}
}
case class MoreCollections(i: Iterable[Int], s: Seq[Int], is: IndexedSeq[Int])
case class MoreCollections(
i: Iterable[Int],
s: Seq[Int],
is: IndexedSeq[Int]
)

case class Enums(
j: JavaEnums.Color,
s: ScalaEnums.Color.Type,
Expand All @@ -94,6 +100,7 @@ object Simple {
sr: List[UnsafeEnum[ScalaEnums.Color.Type]],
ar: List[UnsafeEnum[ADT.Color]]
)

case class Custom(u: URI, d: Duration)

case class LowerCamel(firstField: String, secondField: String, innerField: LowerCamelInner)
Expand Down

0 comments on commit a3a13c5

Please sign in to comment.