diff --git a/parquet/src/main/scala/magnolify/parquet/ParquetField.scala b/parquet/src/main/scala/magnolify/parquet/ParquetField.scala
index 2f2cc902..2458adca 100644
--- a/parquet/src/main/scala/magnolify/parquet/ParquetField.scala
+++ b/parquet/src/main/scala/magnolify/parquet/ParquetField.scala
@@ -51,7 +51,7 @@ sealed trait ParquetField[T] extends Serializable {
   protected final def nonEmpty(v: T): Boolean = !isEmpty(v)
 
   def write(c: RecordConsumer, v: T)(cm: CaseMapper): Unit
-  def newConverter: TypeConverter[T]
+  def newConverter(writerSchema: Type): TypeConverter[T]
 
   protected def writeGroup(c: RecordConsumer, v: T)(cm: CaseMapper): Unit = {
     if (isGroup) {
@@ -83,8 +83,9 @@ object ParquetField {
       override protected def isEmpty(v: T): Boolean = tc.isEmpty(p.dereference(v))
       override def write(c: RecordConsumer, v: T)(cm: CaseMapper): Unit =
         tc.writeGroup(c, p.dereference(v))(cm)
-      override def newConverter: TypeConverter[T] = {
-        val buffered = tc.newConverter
+      override def newConverter(writerSchema: Type): TypeConverter[T] = {
+        val buffered = tc
+          .newConverter(writerSchema)
           .asInstanceOf[TypeConverter.Buffered[p.PType]]
         new TypeConverter.Delegate[p.PType, T](buffered) {
           override def get: T = inner.get(b => caseClass.construct(_ => b.head))
@@ -139,9 +140,10 @@ object ParquetField {
         }
       }
 
-      override def newConverter: TypeConverter[T] =
+      override def newConverter(writerSchema: Type): TypeConverter[T] =
         new GroupConverter with TypeConverter.Buffered[T] {
-          private val fieldConverters = caseClass.parameters.map(_.typeclass.newConverter)
+          private val fieldConverters =
+            caseClass.parameters.map(_.typeclass.newConverter(writerSchema))
 
           override def isPrimitive: Boolean = false
 
@@ -191,8 +193,9 @@ object ParquetField {
     new Primitive[U] {
       override def buildSchema(cm: CaseMapper): Type = pf.schema(cm)
       override def write(c: RecordConsumer, v: U)(cm: CaseMapper): Unit = pf.write(c, g(v))(cm)
-      override def newConverter: TypeConverter[U] =
-        pf.newConverter.asInstanceOf[TypeConverter.Primitive[T]].map(f)
+      override def newConverter(writerSchema: Type): TypeConverter[U] =
+        pf.newConverter(writerSchema).asInstanceOf[TypeConverter.Primitive[T]].map(f)
+
       override type ParquetT = pf.ParquetT
     }
   }
@@ -215,7 +218,7 @@ object ParquetField {
     new Primitive[T] {
       override def buildSchema(cm: CaseMapper): Type = Schema.primitive(ptn, lta)
       override def write(c: RecordConsumer, v: T)(cm: CaseMapper): Unit = f(c)(v)
-      override def newConverter: TypeConverter[T] = g
+      override def newConverter(writerSchema: Type): TypeConverter[T] = g
 
       override type ParquetT = UnderlyingT
     }
@@ -291,8 +294,9 @@ object ParquetField {
       override def write(c: RecordConsumer, v: Option[T])(cm: CaseMapper): Unit =
         v.foreach(t.writeGroup(c, _)(cm))
 
-      override def newConverter: TypeConverter[Option[T]] = {
-        val buffered = t.newConverter
+      override def newConverter(writerSchema: Type): TypeConverter[Option[T]] = {
+        val buffered = t
+          .newConverter(writerSchema)
           .asInstanceOf[TypeConverter.Buffered[T]]
           .withRepetition(Repetition.OPTIONAL)
         new TypeConverter.Delegate[T, Option[T]](buffered) {
@@ -339,15 +343,16 @@ object ParquetField {
           v.foreach(t.writeGroup(c, _)(cm))
       }
 
-      override def newConverter: TypeConverter[C[T]] = {
-        val buffered = t.newConverter
+      override def newConverter(writerSchema: Type): TypeConverter[C[T]] = {
+        val buffered = t
+          .newConverter(writerSchema)
           .asInstanceOf[TypeConverter.Buffered[T]]
           .withRepetition(Repetition.REPEATED)
         val arrayConverter = new TypeConverter.Delegate[T, C[T]](buffered) {
           override def get: C[T] = inner.get(fc.fromSpecific)
         }
 
-        if (hasAvroArray) {
+        if (Schema.hasGroupedArray(writerSchema)) {
           new GroupConverter with TypeConverter.Buffered[C[T]] {
             override def getConverter(fieldIndex: Int): Converter = {
               require(fieldIndex == 0, "Avro array field index != 0")
@@ -421,10 +426,10 @@ object ParquetField {
         }
       }
 
-      override def newConverter: TypeConverter[Map[K, V]] = {
+      override def newConverter(writerSchema: Type): TypeConverter[Map[K, V]] = {
         val kvConverter = new GroupConverter with TypeConverter.Buffered[(K, V)] {
-          private val keyConverter = pfKey.newConverter
-          private val valueConverter = pfValue.newConverter
+          private val keyConverter = pfKey.newConverter(writerSchema)
+          private val valueConverter = pfValue.newConverter(writerSchema)
           private val fieldConverters = Array(keyConverter, valueConverter)
 
           override def isPrimitive: Boolean = false
@@ -466,8 +471,8 @@ object ParquetField {
     def apply[U](f: T => U)(g: U => T)(implicit pf: Primitive[T]): Primitive[U] = new Primitive[U] {
       override def buildSchema(cm: CaseMapper): Type = Schema.setLogicalType(pf.schema(cm), lta)
       override def write(c: RecordConsumer, v: U)(cm: CaseMapper): Unit = pf.write(c, g(v))(cm)
-      override def newConverter: TypeConverter[U] =
-        pf.newConverter.asInstanceOf[TypeConverter.Primitive[T]].map(f)
+      override def newConverter(writerSchema: Type): TypeConverter[U] =
+        pf.newConverter(writerSchema).asInstanceOf[TypeConverter.Primitive[T]].map(f)
 
       override type ParquetT = pf.ParquetT
     }
@@ -509,9 +514,10 @@ object ParquetField {
         override def write(c: RecordConsumer, v: BigDecimal)(cm: CaseMapper): Unit =
           c.addBinary(Binary.fromConstantByteArray(Decimal.toFixed(v, precision, scale, length)))
 
-        override def newConverter: TypeConverter[BigDecimal] = TypeConverter.newByteArray.map { ba =>
-          Decimal.fromBytes(ba, precision, scale)
-        }
+        override def newConverter(writerSchema: Type): TypeConverter[BigDecimal] =
+          TypeConverter.newByteArray.map { ba =>
+            Decimal.fromBytes(ba, precision, scale)
+          }
 
         override type ParquetT = Binary
       }
@@ -544,12 +550,13 @@ object ParquetField {
         )
       )
 
-    override def newConverter: TypeConverter[UUID] = TypeConverter.newByteArray.map { ba =>
-      val bb = ByteBuffer.wrap(ba)
-      val h = bb.getLong
-      val l = bb.getLong
-      new UUID(h, l)
-    }
+    override def newConverter(writerSchema: Type): TypeConverter[UUID] =
+      TypeConverter.newByteArray.map { ba =>
+        val bb = ByteBuffer.wrap(ba)
+        val h = bb.getLong
+        val l = bb.getLong
+        new UUID(h, l)
+      }
 
     override type ParquetT = Binary
   }
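Note on the ParquetField changes: newConverter now receives the writer's schema, so every field's converter can adapt to the layout actually present in the file. Call sites that merely derive or compose instances are unaffected. A minimal sketch of a hand-rolled field built on the from combinator patched above (the Id wrapper type is hypothetical, for illustration only):

    import magnolify.parquet._

    // Hypothetical wrapper type, not part of this PR.
    case class Id(value: String)

    // Maps the existing Primitive[String] onto Id. Under the new API the
    // writer schema is threaded into the delegate converter through
    // newConverter(writerSchema); nothing changes at this call site.
    implicit val idField: ParquetField[Id] =
      ParquetField.from[String](Id(_))(_.value)
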
diff --git a/parquet/src/main/scala/magnolify/parquet/ParquetType.scala b/parquet/src/main/scala/magnolify/parquet/ParquetType.scala
index 6003bfaf..829bfcd5 100644
--- a/parquet/src/main/scala/magnolify/parquet/ParquetType.scala
+++ b/parquet/src/main/scala/magnolify/parquet/ParquetType.scala
@@ -30,7 +30,7 @@ import org.apache.parquet.hadoop.{
 }
 import org.apache.parquet.io.api._
 import org.apache.parquet.io.{InputFile, OutputFile}
-import org.apache.parquet.schema.MessageType
+import org.apache.parquet.schema.{MessageType, Type}
 import org.slf4j.LoggerFactory
 import org.typelevel.scalaccompat.annotation.nowarn
 
@@ -73,7 +73,7 @@ sealed trait ParquetType[T] extends Serializable {
 
   def writeBuilder(file: OutputFile): WriteBuilder[T] = new WriteBuilder(file, writeSupport)
 
   def write(c: RecordConsumer, v: T): Unit = ()
-  def newConverter: TypeConverter[T] = null
+  def newConverter(writerSchema: Type): TypeConverter[T] = null
 }
 
@@ -97,8 +97,10 @@ object ParquetType {
         override val avroCompat: Boolean = pa ==
           ParquetArray.AvroCompat.avroCompat || f.hasAvroArray
+
         override def write(c: RecordConsumer, v: T): Unit = r.write(c, v)(cm)
-        override def newConverter: TypeConverter[T] = r.newConverter
+        override def newConverter(writerSchema: Type): TypeConverter[T] =
+          r.newConverter(writerSchema)
       }
     case _ =>
       throw new IllegalArgumentException(s"ParquetType can only be created from Record. Got $f")
 
@@ -151,9 +153,22 @@ object ParquetType {
       )
     }
 
-    val requestedSchema = Schema.message(parquetType.schema)
+    val requestedSchema = {
+      val s = Schema.message(parquetType.schema)
+      // If reading Avro, roundtrip schema using parquet-avro converter to ensure array compatibility;
+      // magnolify-parquet does not automatically wrap repeated fields into a group like parquet-avro does
+      if (isAvroFile) {
+        val converter = new AvroSchemaConverter()
+        converter.convert(converter.convert(s))
+      } else {
+        s
+      }
+    }
     Schema.checkCompatibility(context.getFileSchema, requestedSchema)
-    new hadoop.ReadSupport.ReadContext(requestedSchema, java.util.Collections.emptyMap())
+    new hadoop.ReadSupport.ReadContext(
+      requestedSchema,
+      java.util.Collections.emptyMap()
+    )
   }
 
   override def prepareForRead(
@@ -163,7 +178,7 @@ object ParquetType {
       readContext: hadoop.ReadSupport.ReadContext
     ): RecordMaterializer[T] =
       new RecordMaterializer[T] {
-        private val root = parquetType.newConverter
+        private val root = parquetType.newConverter(fileSchema)
         override def getCurrentRecord: T = root.get
         override def getRootConverter: GroupConverter = root.asGroupConverter()
       }
diff --git a/parquet/src/main/scala/magnolify/parquet/Predicate.scala b/parquet/src/main/scala/magnolify/parquet/Predicate.scala
index 83d9dd4a..a1e0ce92 100644
--- a/parquet/src/main/scala/magnolify/parquet/Predicate.scala
+++ b/parquet/src/main/scala/magnolify/parquet/Predicate.scala
@@ -65,7 +65,7 @@ object Predicate {
   }
 
   def wrap[T](addFn: (PrimitiveConverter, T) => Unit): T => ScalaFieldT = {
-    lazy val converter = pf.newConverter
+    lazy val converter = pf.newConverter(pf.schema(CaseMapper.identity))
     value => {
       addFn(converter.asPrimitiveConverter(), value)
       converter.get
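With the ReadSupport and RecordMaterializer changes above, the converter tree is built from the schema found in the file being read (fileSchema) rather than assumed from the derived type, which is what lets grouped (parquet-avro style) arrays be handled per file. A sketch of the read path, assuming a hypothetical Record type and file path (HadoopInputFile comes from parquet-hadoop):

    import magnolify.parquet._
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.parquet.hadoop.util.HadoopInputFile

    case class Record(xs: List[Int])

    val conf = new Configuration()
    // "data.parquet" is a placeholder. If it was written by parquet-avro,
    // its xs field is a LIST-annotated group wrapping a repeated element;
    // prepareForRead now hands that fileSchema to newConverter so the
    // matching converter shape is chosen at read time.
    val in = HadoopInputFile.fromPath(new Path("data.parquet"), conf)
    val reader = ParquetType[Record].readBuilder(in).build()
    val first = reader.read()
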
diff --git a/parquet/src/main/scala/magnolify/parquet/Schema.scala b/parquet/src/main/scala/magnolify/parquet/Schema.scala
index e978b22d..e58fae85 100644
--- a/parquet/src/main/scala/magnolify/parquet/Schema.scala
+++ b/parquet/src/main/scala/magnolify/parquet/Schema.scala
@@ -95,6 +95,21 @@ private object Schema {
     builder.named(schema.getName)
   }
 
+  // Check if writer schema encodes arrays as a single repeated field inside an optional or required group
+  private[parquet] def hasGroupedArray(writer: Type): Boolean =
+    !writer.isPrimitive && writer.asGroupType().getFields.asScala.exists {
+      case f if isGroupedArrayType(f) => true
+      case f if !f.isPrimitive        => hasGroupedArray(f)
+      case _                          => false
+    }
+
+  private def isGroupedArrayType(f: Type): Boolean =
+    !f.isPrimitive &&
+      f.getLogicalTypeAnnotation == LogicalTypeAnnotation.listType() && {
+        val fields = f.asGroupType().getFields.asScala
+        fields.size == 1 && fields.head.isRepetition(Repetition.REPEATED)
+      }
+
   def checkCompatibility(writer: Type, reader: Type): Unit = {
     def listFields(gt: GroupType) =
       s"[${gt.getFields.asScala.map(f => s"${f.getName}: ${f.getRepetition}").mkString(",")}]"
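For reference, the two array encodings that hasGroupedArray distinguishes, written out with parquet's MessageTypeParser (schemas are illustrative; hasGroupedArray itself is private[parquet]):

    import org.apache.parquet.schema.MessageTypeParser

    // Bare repeated field, as magnolify-parquet writes arrays by default:
    val flat = MessageTypeParser.parseMessageType(
      "message flat { repeated int32 xs; }"
    )

    // LIST-annotated group wrapping a single repeated field, as parquet-avro
    // writes the same array:
    val grouped = MessageTypeParser.parseMessageType(
      "message grouped { required group xs (LIST) { repeated int32 array; } }"
    )

    // Schema.hasGroupedArray(flat)    == false
    // Schema.hasGroupedArray(grouped) == true  -- selects the wrapping
    // GroupConverter branch in the Seq instance above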