Derive AvroCompat automatically on read
clairemcginty committed Jul 9, 2024
1 parent 7f65cb4 commit c24b4bf
Showing 4 changed files with 71 additions and 34 deletions.
61 changes: 34 additions & 27 deletions parquet/src/main/scala/magnolify/parquet/ParquetField.scala
@@ -51,7 +51,7 @@ sealed trait ParquetField[T] extends Serializable {
 protected final def nonEmpty(v: T): Boolean = !isEmpty(v)
 
 def write(c: RecordConsumer, v: T)(cm: CaseMapper): Unit
-def newConverter: TypeConverter[T]
+def newConverter(writerSchema: Type): TypeConverter[T]
 
 protected def writeGroup(c: RecordConsumer, v: T)(cm: CaseMapper): Unit = {
 if (isGroup) {
@@ -83,8 +83,9 @@ object ParquetField {
 override protected def isEmpty(v: T): Boolean = tc.isEmpty(p.dereference(v))
 override def write(c: RecordConsumer, v: T)(cm: CaseMapper): Unit =
 tc.writeGroup(c, p.dereference(v))(cm)
-override def newConverter: TypeConverter[T] = {
-val buffered = tc.newConverter
+override def newConverter(writerSchema: Type): TypeConverter[T] = {
+val buffered = tc
+.newConverter(writerSchema)
 .asInstanceOf[TypeConverter.Buffered[p.PType]]
 new TypeConverter.Delegate[p.PType, T](buffered) {
 override def get: T = inner.get(b => caseClass.construct(_ => b.head))
@@ -139,9 +140,10 @@ object ParquetField {
 }
 }
 
-override def newConverter: TypeConverter[T] =
+override def newConverter(writerSchema: Type): TypeConverter[T] =
 new GroupConverter with TypeConverter.Buffered[T] {
-private val fieldConverters = caseClass.parameters.map(_.typeclass.newConverter)
+private val fieldConverters =
+caseClass.parameters.map(_.typeclass.newConverter(writerSchema))
 
 override def isPrimitive: Boolean = false
 
@@ -191,8 +193,9 @@ object ParquetField {
 new Primitive[U] {
 override def buildSchema(cm: CaseMapper): Type = pf.schema(cm)
 override def write(c: RecordConsumer, v: U)(cm: CaseMapper): Unit = pf.write(c, g(v))(cm)
-override def newConverter: TypeConverter[U] =
-pf.newConverter.asInstanceOf[TypeConverter.Primitive[T]].map(f)
+override def newConverter(writerSchema: Type): TypeConverter[U] =
+pf.newConverter(writerSchema).asInstanceOf[TypeConverter.Primitive[T]].map(f)
+
 override type ParquetT = pf.ParquetT
 }
 }
@@ -215,7 +218,7 @@ object ParquetField {
 new Primitive[T] {
 override def buildSchema(cm: CaseMapper): Type = Schema.primitive(ptn, lta)
 override def write(c: RecordConsumer, v: T)(cm: CaseMapper): Unit = f(c)(v)
-override def newConverter: TypeConverter[T] = g
+override def newConverter(writerSchema: Type): TypeConverter[T] = g
 override type ParquetT = UnderlyingT
 }
 
@@ -291,8 +294,9 @@ object ParquetField {
 override def write(c: RecordConsumer, v: Option[T])(cm: CaseMapper): Unit =
 v.foreach(t.writeGroup(c, _)(cm))
 
-override def newConverter: TypeConverter[Option[T]] = {
-val buffered = t.newConverter
+override def newConverter(writerSchema: Type): TypeConverter[Option[T]] = {
+val buffered = t
+.newConverter(writerSchema)
 .asInstanceOf[TypeConverter.Buffered[T]]
 .withRepetition(Repetition.OPTIONAL)
 new TypeConverter.Delegate[T, Option[T]](buffered) {
@@ -339,15 +343,16 @@
 v.foreach(t.writeGroup(c, _)(cm))
 }
 
-override def newConverter: TypeConverter[C[T]] = {
-val buffered = t.newConverter
+override def newConverter(writerSchema: Type): TypeConverter[C[T]] = {
+val buffered = t
+.newConverter(writerSchema)
 .asInstanceOf[TypeConverter.Buffered[T]]
 .withRepetition(Repetition.REPEATED)
 val arrayConverter = new TypeConverter.Delegate[T, C[T]](buffered) {
 override def get: C[T] = inner.get(fc.fromSpecific)
 }
 
-if (hasAvroArray) {
+if (Schema.hasGroupedArray(writerSchema)) {
 new GroupConverter with TypeConverter.Buffered[C[T]] {
 override def getConverter(fieldIndex: Int): Converter = {
 require(fieldIndex == 0, "Avro array field index != 0")
@@ -421,10 +426,10 @@
 }
 }
 
-override def newConverter: TypeConverter[Map[K, V]] = {
+override def newConverter(writerSchema: Type): TypeConverter[Map[K, V]] = {
 val kvConverter = new GroupConverter with TypeConverter.Buffered[(K, V)] {
-private val keyConverter = pfKey.newConverter
-private val valueConverter = pfValue.newConverter
+private val keyConverter = pfKey.newConverter(writerSchema)
+private val valueConverter = pfValue.newConverter(writerSchema)
 private val fieldConverters = Array(keyConverter, valueConverter)
 
 override def isPrimitive: Boolean = false
@@ -466,8 +471,8 @@
 def apply[U](f: T => U)(g: U => T)(implicit pf: Primitive[T]): Primitive[U] = new Primitive[U] {
 override def buildSchema(cm: CaseMapper): Type = Schema.setLogicalType(pf.schema(cm), lta)
 override def write(c: RecordConsumer, v: U)(cm: CaseMapper): Unit = pf.write(c, g(v))(cm)
-override def newConverter: TypeConverter[U] =
-pf.newConverter.asInstanceOf[TypeConverter.Primitive[T]].map(f)
+override def newConverter(writerSchema: Type): TypeConverter[U] =
+pf.newConverter(writerSchema).asInstanceOf[TypeConverter.Primitive[T]].map(f)
 
 override type ParquetT = pf.ParquetT
 }
@@ -509,9 +514,10 @@
 override def write(c: RecordConsumer, v: BigDecimal)(cm: CaseMapper): Unit =
 c.addBinary(Binary.fromConstantByteArray(Decimal.toFixed(v, precision, scale, length)))
 
-override def newConverter: TypeConverter[BigDecimal] = TypeConverter.newByteArray.map { ba =>
-Decimal.fromBytes(ba, precision, scale)
-}
+override def newConverter(writerSchema: Type): TypeConverter[BigDecimal] =
+TypeConverter.newByteArray.map { ba =>
+Decimal.fromBytes(ba, precision, scale)
+}
 
 override type ParquetT = Binary
 }
@@ -544,12 +550,13 @@
 )
 )
 
-override def newConverter: TypeConverter[UUID] = TypeConverter.newByteArray.map { ba =>
-val bb = ByteBuffer.wrap(ba)
-val h = bb.getLong
-val l = bb.getLong
-new UUID(h, l)
-}
+override def newConverter(writerSchema: Type): TypeConverter[UUID] =
+TypeConverter.newByteArray.map { ba =>
+val bb = ByteBuffer.wrap(ba)
+val h = bb.getLong
+val l = bb.getLong
+new UUID(h, l)
+}
 
 override type ParquetT = Binary
 }
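To make the collection change above (`hasAvroArray` replaced by `Schema.hasGroupedArray(writerSchema)`) concrete: the same case class can be written with two different Parquet list encodings, depending on whether the writer opted into AvroCompat. A rough, illustrative sketch only; it assumes the `schema` accessor on `ParquetType` and the `AvroCompat` import referenced in ParquetType.scala below, and `Record` is a made-up type:

import magnolify.parquet._

case class Record(xs: List[Int])

// Default derivation: the list becomes a bare repeated field, roughly `repeated int32 xs`.
val plain = ParquetType[Record]
println(plain.schema)

// With AvroCompat in scope, the list is wrapped parquet-avro style: a LIST-annotated
// group containing a single repeated field. This is the shape the new read-time
// detection looks for in the writer schema.
locally {
  import magnolify.parquet.ParquetArray.AvroCompat._
  val avroStyle = ParquetType[Record]
  println(avroStyle.schema)
}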
27 changes: 21 additions & 6 deletions parquet/src/main/scala/magnolify/parquet/ParquetType.scala
@@ -30,7 +30,7 @@ import org.apache.parquet.hadoop.{
 }
 import org.apache.parquet.io.api._
 import org.apache.parquet.io.{InputFile, OutputFile}
-import org.apache.parquet.schema.MessageType
+import org.apache.parquet.schema.{MessageType, Type}
 import org.slf4j.LoggerFactory
 import org.typelevel.scalaccompat.annotation.nowarn
 
@@ -73,7 +73,7 @@ sealed trait ParquetType[T] extends Serializable {
 def writeBuilder(file: OutputFile): WriteBuilder[T] = new WriteBuilder(file, writeSupport)
 
 def write(c: RecordConsumer, v: T): Unit = ()
-def newConverter: TypeConverter[T] = null
+def newConverter(writerSchema: Type): TypeConverter[T] = null
 }
 
 object ParquetType {
@@ -97,8 +97,10 @@ object ParquetType {
 
 override val avroCompat: Boolean =
 pa == ParquetArray.AvroCompat.avroCompat || f.hasAvroArray
+
 override def write(c: RecordConsumer, v: T): Unit = r.write(c, v)(cm)
-override def newConverter: TypeConverter[T] = r.newConverter
+override def newConverter(writerSchema: Type): TypeConverter[T] =
+r.newConverter(writerSchema)
 }
 case _ =>
 throw new IllegalArgumentException(s"ParquetType can only be created from Record. Got $f")
@@ -151,9 +153,22 @@ object ParquetType {
 )
 }
 
-val requestedSchema = Schema.message(parquetType.schema)
+val requestedSchema = {
+val s = Schema.message(parquetType.schema)
+// If reading Avro, roundtrip schema using parquet-avro converter to ensure array compatibility;
+// magnolify-parquet does not automatically wrap repeated fields into a group like parquet-avro does
+if (isAvroFile) {
+val converter = new AvroSchemaConverter()
+converter.convert(converter.convert(s))
+} else {
+s
+}
+}
 Schema.checkCompatibility(context.getFileSchema, requestedSchema)
-new hadoop.ReadSupport.ReadContext(requestedSchema, java.util.Collections.emptyMap())
+new hadoop.ReadSupport.ReadContext(
+requestedSchema,
+java.util.Collections.emptyMap()
+)
 }
 
 override def prepareForRead(
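The schema roundtrip added above is the heart of deriving AvroCompat automatically on read: when the file was written by parquet-avro, the requested projection is rewritten into parquet-avro's list encoding instead of requiring the reader to opt in. A standalone sketch of that conversion, using the real parquet-avro `AvroSchemaConverter` API but a hypothetical helper name:

import org.apache.parquet.avro.AvroSchemaConverter
import org.apache.parquet.schema.MessageType

// Parquet -> Avro -> Parquet: the roundtrip wraps every repeated field in the
// LIST-annotated group structure that parquet-avro readers and writers expect.
def toAvroCompatible(requested: MessageType): MessageType = {
  val converter = new AvroSchemaConverter()
  val avroSchema = converter.convert(requested) // MessageType -> org.apache.avro.Schema
  converter.convert(avroSchema)                 // org.apache.avro.Schema -> MessageType
}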
@@ -163,7 +178,7 @@
 readContext: hadoop.ReadSupport.ReadContext
 ): RecordMaterializer[T] =
 new RecordMaterializer[T] {
-private val root = parquetType.newConverter
+private val root = parquetType.newConverter(fileSchema)
 override def getCurrentRecord: T = root.get
 override def getRootConverter: GroupConverter = root.asGroupConverter()
 }
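Together with the `fileSchema` now passed to `newConverter` in `prepareForRead`, the user-visible effect is that the read side no longer needs the AvroCompat import; the array encoding is detected per file from the writer schema in the footer. A hedged usage sketch, assuming the existing `readBuilder`/`ParquetReader` API and treating `Record` as a made-up example type:

import magnolify.parquet._
import org.apache.parquet.io.InputFile

case class Record(xs: List[Int])

// No `import magnolify.parquet.ParquetArray.AvroCompat._` needed on the read side:
// whether `xs` was written as a bare repeated field or as an Avro-style list group
// is determined from the writer schema of the file being read.
def readAll(file: InputFile): Seq[Record] = {
  val reader = ParquetType[Record].readBuilder(file).build()
  Iterator.continually(reader.read()).takeWhile(_ != null).toSeq
}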
2 changes: 1 addition & 1 deletion parquet/src/main/scala/magnolify/parquet/Predicate.scala
@@ -65,7 +65,7 @@ object Predicate {
 }
 
 def wrap[T](addFn: (PrimitiveConverter, T) => Unit): T => ScalaFieldT = {
-lazy val converter = pf.newConverter
+lazy val converter = pf.newConverter(pf.schema(CaseMapper.identity))
 value => {
 addFn(converter.asPrimitiveConverter(), value)
 converter.get
15 changes: 15 additions & 0 deletions parquet/src/main/scala/magnolify/parquet/Schema.scala
@@ -95,6 +95,21 @@ private object Schema {
 builder.named(schema.getName)
 }
 
+// Check if writer schema encodes arrays as a single repeated field inside of an optional or required group
+private[parquet] def hasGroupedArray(writer: Type): Boolean =
+!writer.isPrimitive && writer.asGroupType().getFields.asScala.exists {
+case f if isGroupedArrayType(f) => true
+case f if !f.isPrimitive => f.asGroupType().getFields.asScala.exists(hasGroupedArray)
+case _ => false
+}
+
+private def isGroupedArrayType(f: Type): Boolean =
+!f.isPrimitive &&
+f.getLogicalTypeAnnotation == LogicalTypeAnnotation.listType() && {
+val fields = f.asGroupType().getFields.asScala
+fields.size == 1 && fields.head.isRepetition(Repetition.REPEATED)
+}
+
 def checkCompatibility(writer: Type, reader: Type): Unit = {
 def listFields(gt: GroupType) =
 s"[${gt.getFields.asScala.map(f => s"${f.getName}: ${f.getRepetition}").mkString(",")}]"
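For reference, these are the two writer-schema shapes that `hasGroupedArray` is meant to tell apart, built here with parquet's schema parser purely for illustration (field names are arbitrary):

import org.apache.parquet.schema.MessageTypeParser

// Grouped array, as produced by parquet-avro (or magnolify with AvroCompat):
// a LIST-annotated group holding a single repeated field. hasGroupedArray
// should report true for a writer schema containing this shape.
val grouped = MessageTypeParser.parseMessageType(
  """message Record {
    |  required group xs (LIST) {
    |    repeated int32 array;
    |  }
    |}""".stripMargin
)

// Bare repeated field, as produced by magnolify without AvroCompat:
// no LIST group, so hasGroupedArray should report false.
val bare = MessageTypeParser.parseMessageType(
  """message Record {
    |  repeated int32 xs;
    |}""".stripMargin
)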
