From c24b4bf780a8e8c5a6a90248afced113177d1b24 Mon Sep 17 00:00:00 2001
From: Claire McGinty <clairem@spotify.com>
Date: Tue, 9 Jul 2024 09:30:25 -0700
Subject: [PATCH] Derive AvroCompat automatically on read

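Thread the writer schema from the Parquet file footer through
ParquetField#newConverter, so Avro-style grouped arrays are detected per
file (via the new Schema.hasGroupedArray check) instead of requiring a
read-side opt-in through the ParquetArray.AvroCompat import. When the file
is detected as Avro-written (isAvroFile), the requested read schema is also
round-tripped through parquet-avro's AvroSchemaConverter so its array
layout matches what the writer produced.

A rough read-side sketch of the intended effect follows; readBuilder and
the Hadoop helpers below are illustrative and not part of this patch:

    import magnolify.parquet._
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.parquet.hadoop.util.HadoopInputFile

    case class MyRecord(id: Int, tags: List[String])

    // Illustrative: no AvroCompat import is needed on the read path any
    // more; the writer schema read from the file footer determines how
    // repeated (array) fields are decoded.
    val conf = new Configuration()
    val inputFile = HadoopInputFile.fromPath(new Path("records.parquet"), conf)
    val pt = ParquetType[MyRecord]
    val reader = pt.readBuilder(inputFile).build()
    val first: MyRecord = reader.read()
    reader.close()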
---
 .../magnolify/parquet/ParquetField.scala      | 61 +++++++++++--------
 .../scala/magnolify/parquet/ParquetType.scala | 27 ++++++--
 .../scala/magnolify/parquet/Predicate.scala   |  2 +-
 .../main/scala/magnolify/parquet/Schema.scala | 15 +++++
 4 files changed, 71 insertions(+), 34 deletions(-)

diff --git a/parquet/src/main/scala/magnolify/parquet/ParquetField.scala b/parquet/src/main/scala/magnolify/parquet/ParquetField.scala
index 2f2cc902e..2458adca4 100644
--- a/parquet/src/main/scala/magnolify/parquet/ParquetField.scala
+++ b/parquet/src/main/scala/magnolify/parquet/ParquetField.scala
@@ -51,7 +51,7 @@ sealed trait ParquetField[T] extends Serializable {
   protected final def nonEmpty(v: T): Boolean = !isEmpty(v)
 
   def write(c: RecordConsumer, v: T)(cm: CaseMapper): Unit
-  def newConverter: TypeConverter[T]
+  def newConverter(writerSchema: Type): TypeConverter[T]
 
   protected def writeGroup(c: RecordConsumer, v: T)(cm: CaseMapper): Unit = {
     if (isGroup) {
@@ -83,8 +83,9 @@ object ParquetField {
         override protected def isEmpty(v: T): Boolean = tc.isEmpty(p.dereference(v))
         override def write(c: RecordConsumer, v: T)(cm: CaseMapper): Unit =
           tc.writeGroup(c, p.dereference(v))(cm)
-        override def newConverter: TypeConverter[T] = {
-          val buffered = tc.newConverter
+        override def newConverter(writerSchema: Type): TypeConverter[T] = {
+          val buffered = tc
+            .newConverter(writerSchema)
             .asInstanceOf[TypeConverter.Buffered[p.PType]]
           new TypeConverter.Delegate[p.PType, T](buffered) {
             override def get: T = inner.get(b => caseClass.construct(_ => b.head))
@@ -139,9 +140,10 @@ object ParquetField {
           }
         }
 
-        override def newConverter: TypeConverter[T] =
+        override def newConverter(writerSchema: Type): TypeConverter[T] =
           new GroupConverter with TypeConverter.Buffered[T] {
-            private val fieldConverters = caseClass.parameters.map(_.typeclass.newConverter)
+            private val fieldConverters =
+              caseClass.parameters.map(_.typeclass.newConverter(writerSchema))
 
             override def isPrimitive: Boolean = false
 
@@ -191,8 +193,9 @@ object ParquetField {
       new Primitive[U] {
         override def buildSchema(cm: CaseMapper): Type = pf.schema(cm)
         override def write(c: RecordConsumer, v: U)(cm: CaseMapper): Unit = pf.write(c, g(v))(cm)
-        override def newConverter: TypeConverter[U] =
-          pf.newConverter.asInstanceOf[TypeConverter.Primitive[T]].map(f)
+        override def newConverter(writerSchema: Type): TypeConverter[U] =
+          pf.newConverter(writerSchema).asInstanceOf[TypeConverter.Primitive[T]].map(f)
+
         override type ParquetT = pf.ParquetT
       }
   }
@@ -215,7 +218,7 @@ object ParquetField {
     new Primitive[T] {
       override def buildSchema(cm: CaseMapper): Type = Schema.primitive(ptn, lta)
       override def write(c: RecordConsumer, v: T)(cm: CaseMapper): Unit = f(c)(v)
-      override def newConverter: TypeConverter[T] = g
+      override def newConverter(writerSchema: Type): TypeConverter[T] = g
       override type ParquetT = UnderlyingT
     }
 
@@ -291,8 +294,9 @@ object ParquetField {
       override def write(c: RecordConsumer, v: Option[T])(cm: CaseMapper): Unit =
         v.foreach(t.writeGroup(c, _)(cm))
 
-      override def newConverter: TypeConverter[Option[T]] = {
-        val buffered = t.newConverter
+      override def newConverter(writerSchema: Type): TypeConverter[Option[T]] = {
+        val buffered = t
+          .newConverter(writerSchema)
           .asInstanceOf[TypeConverter.Buffered[T]]
           .withRepetition(Repetition.OPTIONAL)
         new TypeConverter.Delegate[T, Option[T]](buffered) {
@@ -339,15 +343,16 @@ object ParquetField {
           v.foreach(t.writeGroup(c, _)(cm))
         }
 
-      override def newConverter: TypeConverter[C[T]] = {
-        val buffered = t.newConverter
+      override def newConverter(writerSchema: Type): TypeConverter[C[T]] = {
+        val buffered = t
+          .newConverter(writerSchema)
           .asInstanceOf[TypeConverter.Buffered[T]]
           .withRepetition(Repetition.REPEATED)
         val arrayConverter = new TypeConverter.Delegate[T, C[T]](buffered) {
           override def get: C[T] = inner.get(fc.fromSpecific)
         }
 
-        if (hasAvroArray) {
+        if (Schema.hasGroupedArray(writerSchema)) {
           new GroupConverter with TypeConverter.Buffered[C[T]] {
             override def getConverter(fieldIndex: Int): Converter = {
               require(fieldIndex == 0, "Avro array field index != 0")
@@ -421,10 +426,10 @@ object ParquetField {
         }
       }
 
-      override def newConverter: TypeConverter[Map[K, V]] = {
+      override def newConverter(writerSchema: Type): TypeConverter[Map[K, V]] = {
         val kvConverter = new GroupConverter with TypeConverter.Buffered[(K, V)] {
-          private val keyConverter = pfKey.newConverter
-          private val valueConverter = pfValue.newConverter
+          private val keyConverter = pfKey.newConverter(writerSchema)
+          private val valueConverter = pfValue.newConverter(writerSchema)
           private val fieldConverters = Array(keyConverter, valueConverter)
 
           override def isPrimitive: Boolean = false
@@ -466,8 +471,8 @@ object ParquetField {
     def apply[U](f: T => U)(g: U => T)(implicit pf: Primitive[T]): Primitive[U] = new Primitive[U] {
       override def buildSchema(cm: CaseMapper): Type = Schema.setLogicalType(pf.schema(cm), lta)
       override def write(c: RecordConsumer, v: U)(cm: CaseMapper): Unit = pf.write(c, g(v))(cm)
-      override def newConverter: TypeConverter[U] =
-        pf.newConverter.asInstanceOf[TypeConverter.Primitive[T]].map(f)
+      override def newConverter(writerSchema: Type): TypeConverter[U] =
+        pf.newConverter(writerSchema).asInstanceOf[TypeConverter.Primitive[T]].map(f)
 
       override type ParquetT = pf.ParquetT
     }
@@ -509,9 +514,10 @@ object ParquetField {
       override def write(c: RecordConsumer, v: BigDecimal)(cm: CaseMapper): Unit =
         c.addBinary(Binary.fromConstantByteArray(Decimal.toFixed(v, precision, scale, length)))
 
-      override def newConverter: TypeConverter[BigDecimal] = TypeConverter.newByteArray.map { ba =>
-        Decimal.fromBytes(ba, precision, scale)
-      }
+      override def newConverter(writerSchema: Type): TypeConverter[BigDecimal] =
+        TypeConverter.newByteArray.map { ba =>
+          Decimal.fromBytes(ba, precision, scale)
+        }
 
       override type ParquetT = Binary
     }
@@ -544,12 +550,13 @@ object ParquetField {
         )
       )
 
-    override def newConverter: TypeConverter[UUID] = TypeConverter.newByteArray.map { ba =>
-      val bb = ByteBuffer.wrap(ba)
-      val h = bb.getLong
-      val l = bb.getLong
-      new UUID(h, l)
-    }
+    override def newConverter(writerSchema: Type): TypeConverter[UUID] =
+      TypeConverter.newByteArray.map { ba =>
+        val bb = ByteBuffer.wrap(ba)
+        val h = bb.getLong
+        val l = bb.getLong
+        new UUID(h, l)
+      }
 
     override type ParquetT = Binary
   }
diff --git a/parquet/src/main/scala/magnolify/parquet/ParquetType.scala b/parquet/src/main/scala/magnolify/parquet/ParquetType.scala
index 6003bfaf8..829bfcd58 100644
--- a/parquet/src/main/scala/magnolify/parquet/ParquetType.scala
+++ b/parquet/src/main/scala/magnolify/parquet/ParquetType.scala
@@ -30,7 +30,7 @@ import org.apache.parquet.hadoop.{
 }
 import org.apache.parquet.io.api._
 import org.apache.parquet.io.{InputFile, OutputFile}
-import org.apache.parquet.schema.MessageType
+import org.apache.parquet.schema.{MessageType, Type}
 import org.slf4j.LoggerFactory
 import org.typelevel.scalaccompat.annotation.nowarn
 
@@ -73,7 +73,7 @@ sealed trait ParquetType[T] extends Serializable {
   def writeBuilder(file: OutputFile): WriteBuilder[T] = new WriteBuilder(file, writeSupport)
 
   def write(c: RecordConsumer, v: T): Unit = ()
-  def newConverter: TypeConverter[T] = null
+  def newConverter(writerSchema: Type): TypeConverter[T] = null
 }
 
 object ParquetType {
@@ -97,8 +97,10 @@ object ParquetType {
 
         override val avroCompat: Boolean =
           pa == ParquetArray.AvroCompat.avroCompat || f.hasAvroArray
+
         override def write(c: RecordConsumer, v: T): Unit = r.write(c, v)(cm)
-        override def newConverter: TypeConverter[T] = r.newConverter
+        override def newConverter(writerSchema: Type): TypeConverter[T] =
+          r.newConverter(writerSchema)
       }
     case _ =>
       throw new IllegalArgumentException(s"ParquetType can only be created from Record. Got $f")
@@ -151,9 +153,22 @@ object ParquetType {
         )
       }
 
-      val requestedSchema = Schema.message(parquetType.schema)
+      val requestedSchema = {
+        val s = Schema.message(parquetType.schema)
+        // If reading Avro, round-trip the schema via the parquet-avro converter to ensure array compatibility;
+        // magnolify-parquet does not automatically wrap repeated fields into a group the way parquet-avro does
+        if (isAvroFile) {
+          val converter = new AvroSchemaConverter()
+          converter.convert(converter.convert(s))
+        } else {
+          s
+        }
+      }
       Schema.checkCompatibility(context.getFileSchema, requestedSchema)
-      new hadoop.ReadSupport.ReadContext(requestedSchema, java.util.Collections.emptyMap())
+      new hadoop.ReadSupport.ReadContext(
+        requestedSchema,
+        java.util.Collections.emptyMap()
+      )
     }
 
     override def prepareForRead(
@@ -163,7 +178,7 @@ object ParquetType {
       readContext: hadoop.ReadSupport.ReadContext
     ): RecordMaterializer[T] =
       new RecordMaterializer[T] {
-        private val root = parquetType.newConverter
+        private val root = parquetType.newConverter(fileSchema)
         override def getCurrentRecord: T = root.get
         override def getRootConverter: GroupConverter = root.asGroupConverter()
       }
diff --git a/parquet/src/main/scala/magnolify/parquet/Predicate.scala b/parquet/src/main/scala/magnolify/parquet/Predicate.scala
index 83d9dd4a3..a1e0ce927 100644
--- a/parquet/src/main/scala/magnolify/parquet/Predicate.scala
+++ b/parquet/src/main/scala/magnolify/parquet/Predicate.scala
@@ -65,7 +65,7 @@ object Predicate {
     }
 
     def wrap[T](addFn: (PrimitiveConverter, T) => Unit): T => ScalaFieldT = {
-      lazy val converter = pf.newConverter
+      lazy val converter = pf.newConverter(pf.schema(CaseMapper.identity))
       value => {
         addFn(converter.asPrimitiveConverter(), value)
         converter.get
diff --git a/parquet/src/main/scala/magnolify/parquet/Schema.scala b/parquet/src/main/scala/magnolify/parquet/Schema.scala
index e978b22d6..e58fae85b 100644
--- a/parquet/src/main/scala/magnolify/parquet/Schema.scala
+++ b/parquet/src/main/scala/magnolify/parquet/Schema.scala
@@ -95,6 +95,21 @@ private object Schema {
     builder.named(schema.getName)
   }
 
+  // Check whether the writer schema encodes arrays as a single repeated field inside an optional or required group
+  private[parquet] def hasGroupedArray(writer: Type): Boolean =
+    !writer.isPrimitive && writer.asGroupType().getFields.asScala.exists {
+      case f if isGroupedArrayType(f) => true
+      case f if !f.isPrimitive        => hasGroupedArray(f)
+      case _                          => false
+    }
+
+  private def isGroupedArrayType(f: Type): Boolean =
+    !f.isPrimitive &&
+      f.getLogicalTypeAnnotation == LogicalTypeAnnotation.listType() && {
+        val fields = f.asGroupType().getFields.asScala
+        fields.size == 1 && fields.head.isRepetition(Repetition.REPEATED)
+      }
+
   def checkCompatibility(writer: Type, reader: Type): Unit = {
     def listFields(gt: GroupType) =
       s"[${gt.getFields.asScala.map(f => s"${f.getName}: ${f.getRepetition}").mkString(",")}]"