From 4c05290a51171c9a73260e30e09b8c3d39b2b0d9 Mon Sep 17 00:00:00 2001 From: jarvis Date: Mon, 6 Nov 2023 23:12:55 +0800 Subject: [PATCH] [Feature][avro-format] improve deserialization --- .../e2e/connector/kafka/KafkaIT.java | 3 + .../avro/AvroDeserializationSchema.java | 2 +- .../format/avro/AvroSerializationSchema.java | 3 +- .../format/avro/AvroToRowConverter.java | 31 ++--- .../format/avro/RowToAvroConverter.java | 100 ++-------------- ...SeaTunnelRowTypeToAvroSchemaConverter.java | 113 ++++++++++++++++++ .../format/avro/AvroConverterTest.java | 2 +- .../avro/AvroSerializationSchemaTest.java | 24 +++- 8 files changed, 154 insertions(+), 124 deletions(-) create mode 100644 seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/SeaTunnelRowTypeToAvroSchemaConverter.java diff --git a/seatunnel-e2e/seatunnel-connector-v2-e2e/connector-kafka-e2e/src/test/java/org/apache/seatunnel/e2e/connector/kafka/KafkaIT.java b/seatunnel-e2e/seatunnel-connector-v2-e2e/connector-kafka-e2e/src/test/java/org/apache/seatunnel/e2e/connector/kafka/KafkaIT.java index b95c90e3fe8..0bf222bf87d 100644 --- a/seatunnel-e2e/seatunnel-connector-v2-e2e/connector-kafka-e2e/src/test/java/org/apache/seatunnel/e2e/connector/kafka/KafkaIT.java +++ b/seatunnel-e2e/seatunnel-connector-v2-e2e/connector-kafka-e2e/src/test/java/org/apache/seatunnel/e2e/connector/kafka/KafkaIT.java @@ -34,6 +34,7 @@ import org.apache.seatunnel.e2e.common.TestResource; import org.apache.seatunnel.e2e.common.TestSuiteBase; import org.apache.seatunnel.e2e.common.container.TestContainer; +import org.apache.seatunnel.e2e.common.container.TestContainerId; import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer; import org.apache.seatunnel.format.text.TextSerializationSchema; @@ -293,6 +294,7 @@ public void testSourceKafkaStartConfig(TestContainer container) } @TestTemplate + @DisabledOnContainer(TestContainerId.SPARK_2_4) public void testFakeSourceToKafkaAvroFormat(TestContainer container) throws IOException, InterruptedException { Container.ExecResult execResult = @@ -301,6 +303,7 @@ public void testFakeSourceToKafkaAvroFormat(TestContainer container) } @TestTemplate + @DisabledOnContainer(TestContainerId.SPARK_2_4) public void testKafkaAvroToConsole(TestContainer container) throws IOException, InterruptedException { DefaultSeaTunnelRowSerializer serializer = diff --git a/seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/AvroDeserializationSchema.java b/seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/AvroDeserializationSchema.java index 3d2c3fba595..b682a8e6431 100644 --- a/seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/AvroDeserializationSchema.java +++ b/seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/AvroDeserializationSchema.java @@ -37,7 +37,7 @@ public class AvroDeserializationSchema implements DeserializationSchema reader = null; + private Schema schema; - public AvroToRowConverter() {} + public AvroToRowConverter(SeaTunnelRowType rowType) { + schema = SeaTunnelRowTypeToAvroSchemaConverter.buildAvroSchemaWithRowType(rowType); + } public DatumReader getReader() { if (reader == null) { @@ -59,12 +54,12 @@ public DatumReader getReader() { } private DatumReader createReader() { - GenericDatumReader datumReader = new GenericDatumReader<>(); + GenericDatumReader datumReader = new GenericDatumReader<>(schema, schema); datumReader.getData().addLogicalTypeConversion(new Conversions.DecimalConversion()); datumReader.getData().addLogicalTypeConversion(new TimeConversions.DateConversion()); datumReader .getData() - .addLogicalTypeConversion(new TimeConversions.TimestampMillisConversion()); + .addLogicalTypeConversion(new TimeConversions.LocalTimestampMillisConversion()); return datumReader; } @@ -98,6 +93,9 @@ private Object convertField(SeaTunnelDataType dataType, Schema.Field field, O case DOUBLE: case NULL: case BYTES: + case DATE: + case DECIMAL: + case TIMESTAMP: return val; case TINYINT: Class typeClass = dataType.getTypeClass(); @@ -110,17 +108,6 @@ private Object convertField(SeaTunnelDataType dataType, Schema.Field field, O BasicType basicType = ((ArrayType) dataType).getElementType(); List list = (List) val; return convertArray(list, basicType); - case DECIMAL: - LogicalTypes.Decimal decimal = - (LogicalTypes.Decimal) field.schema().getLogicalType(); - ByteBuffer buffer = (ByteBuffer) val; - byte[] bytes = buffer.array(); - return new BigDecimal(new BigInteger(bytes), decimal.getScale()); - case DATE: - return LocalDate.ofEpochDay((Long) val); - case TIMESTAMP: - return LocalDateTime.ofInstant( - Instant.ofEpochMilli((Long) val), ZoneId.systemDefault()); case ROW: SeaTunnelRowType subRow = (SeaTunnelRowType) dataType; return converter((GenericRecord) val, subRow); diff --git a/seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/RowToAvroConverter.java b/seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/RowToAvroConverter.java index b1bd3360177..f8f0652a26c 100644 --- a/seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/RowToAvroConverter.java +++ b/seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/RowToAvroConverter.java @@ -20,8 +20,6 @@ import org.apache.seatunnel.api.table.type.ArrayType; import org.apache.seatunnel.api.table.type.BasicType; -import org.apache.seatunnel.api.table.type.DecimalType; -import org.apache.seatunnel.api.table.type.MapType; import org.apache.seatunnel.api.table.type.SeaTunnelDataType; import org.apache.seatunnel.api.table.type.SeaTunnelRow; import org.apache.seatunnel.api.table.type.SeaTunnelRowType; @@ -29,7 +27,6 @@ import org.apache.seatunnel.format.avro.exception.SeaTunnelAvroFormatException; import org.apache.avro.Conversions; -import org.apache.avro.LogicalTypes; import org.apache.avro.Schema; import org.apache.avro.data.TimeConversions; import org.apache.avro.generic.GenericDatumWriter; @@ -38,11 +35,7 @@ import org.apache.avro.io.DatumWriter; import java.io.Serializable; -import java.math.BigDecimal; import java.nio.ByteBuffer; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.ZoneId; import java.util.ArrayList; import java.util.List; @@ -55,7 +48,7 @@ public class RowToAvroConverter implements Serializable { private final DatumWriter writer; public RowToAvroConverter(SeaTunnelRowType rowType) { - this.schema = buildAvroSchemaWithRowType(rowType); + this.schema = SeaTunnelRowTypeToAvroSchemaConverter.buildAvroSchemaWithRowType(rowType); this.rowType = rowType; this.writer = createWriter(); } @@ -66,7 +59,7 @@ private DatumWriter createWriter() { datumWriter.getData().addLogicalTypeConversion(new TimeConversions.DateConversion()); datumWriter .getData() - .addLogicalTypeConversion(new TimeConversions.TimestampMillisConversion()); + .addLogicalTypeConversion(new TimeConversions.LocalTimestampMillisConversion()); return datumWriter; } @@ -89,80 +82,6 @@ public GenericRecord convertRowToGenericRecord(SeaTunnelRow element) { return builder.build(); } - private Schema buildAvroSchemaWithRowType(SeaTunnelRowType seaTunnelRowType) { - List fields = new ArrayList<>(); - SeaTunnelDataType[] fieldTypes = seaTunnelRowType.getFieldTypes(); - String[] fieldNames = seaTunnelRowType.getFieldNames(); - for (int i = 0; i < fieldNames.length; i++) { - fields.add(generateField(fieldNames[i], fieldTypes[i])); - } - return Schema.createRecord("SeaTunnelRecord", null, null, false, fields); - } - - private Schema.Field generateField(String fieldName, SeaTunnelDataType seaTunnelDataType) { - return new Schema.Field( - fieldName, - seaTunnelDataType2AvroDataType(fieldName, seaTunnelDataType), - null, - null); - } - - private Schema seaTunnelDataType2AvroDataType( - String fieldName, SeaTunnelDataType seaTunnelDataType) { - - switch (seaTunnelDataType.getSqlType()) { - case STRING: - return Schema.create(Schema.Type.STRING); - case BYTES: - return Schema.create(Schema.Type.BYTES); - case TINYINT: - case SMALLINT: - case INT: - return Schema.create(Schema.Type.INT); - case BIGINT: - return Schema.create(Schema.Type.LONG); - case FLOAT: - return Schema.create(Schema.Type.FLOAT); - case DOUBLE: - return Schema.create(Schema.Type.DOUBLE); - case BOOLEAN: - return Schema.create(Schema.Type.BOOLEAN); - case MAP: - SeaTunnelDataType valueType = ((MapType) seaTunnelDataType).getValueType(); - return Schema.createMap(seaTunnelDataType2AvroDataType(fieldName, valueType)); - case ARRAY: - BasicType elementType = ((ArrayType) seaTunnelDataType).getElementType(); - return Schema.createArray(seaTunnelDataType2AvroDataType(fieldName, elementType)); - case ROW: - SeaTunnelDataType[] fieldTypes = - ((SeaTunnelRowType) seaTunnelDataType).getFieldTypes(); - String[] fieldNames = ((SeaTunnelRowType) seaTunnelDataType).getFieldNames(); - List subField = new ArrayList<>(); - for (int i = 0; i < fieldNames.length; i++) { - subField.add(generateField(fieldNames[i], fieldTypes[i])); - } - return Schema.createRecord(fieldName, null, null, false, subField); - case DECIMAL: - int precision = ((DecimalType) seaTunnelDataType).getPrecision(); - int scale = ((DecimalType) seaTunnelDataType).getScale(); - LogicalTypes.Decimal decimal = LogicalTypes.decimal(precision, scale); - return decimal.addToSchema(Schema.create(Schema.Type.BYTES)); - case TIMESTAMP: - return LogicalTypes.timestampMillis().addToSchema(Schema.create(Schema.Type.LONG)); - case DATE: - return LogicalTypes.date().addToSchema(Schema.create(Schema.Type.INT)); - case NULL: - return Schema.create(Schema.Type.NULL); - default: - String errorMsg = - String.format( - "SeaTunnel avro format is not supported for this data type [%s]", - seaTunnelDataType.getSqlType()); - throw new SeaTunnelAvroFormatException( - AvroFormatErrorCode.UNSUPPORTED_DATA_TYPE, errorMsg); - } - } - private Object resolveObject(Object data, SeaTunnelDataType seaTunnelDataType) { if (data == null) { return null; @@ -176,6 +95,9 @@ private Object resolveObject(Object data, SeaTunnelDataType seaTunnelDataType case DOUBLE: case BOOLEAN: case MAP: + case DECIMAL: + case DATE: + case TIMESTAMP: return data; case TINYINT: Class typeClass = seaTunnelDataType.getTypeClass(); @@ -186,12 +108,6 @@ private Object resolveObject(Object data, SeaTunnelDataType seaTunnelDataType } } return data; - case DECIMAL: - BigDecimal decimal = (BigDecimal) data; - return ByteBuffer.wrap(decimal.unscaledValue().toByteArray()); - case DATE: - LocalDate localDate = (LocalDate) data; - return localDate.toEpochDay(); case BYTES: return ByteBuffer.wrap((byte[]) data); case ARRAY: @@ -211,7 +127,8 @@ private Object resolveObject(Object data, SeaTunnelDataType seaTunnelDataType ((SeaTunnelRowType) seaTunnelDataType).getFieldTypes(); String[] fieldNames = ((SeaTunnelRowType) seaTunnelDataType).getFieldNames(); Schema recordSchema = - buildAvroSchemaWithRowType((SeaTunnelRowType) seaTunnelDataType); + SeaTunnelRowTypeToAvroSchemaConverter.buildAvroSchemaWithRowType( + (SeaTunnelRowType) seaTunnelDataType); GenericRecordBuilder recordBuilder = new GenericRecordBuilder(recordSchema); for (int i = 0; i < fieldNames.length; i++) { recordBuilder.set( @@ -219,9 +136,6 @@ private Object resolveObject(Object data, SeaTunnelDataType seaTunnelDataType resolveObject(seaTunnelRow.getField(i), fieldTypes[i])); } return recordBuilder.build(); - case TIMESTAMP: - LocalDateTime dateTime = (LocalDateTime) data; - return (dateTime).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(); default: String errorMsg = String.format( diff --git a/seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/SeaTunnelRowTypeToAvroSchemaConverter.java b/seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/SeaTunnelRowTypeToAvroSchemaConverter.java new file mode 100644 index 00000000000..195ff8004c5 --- /dev/null +++ b/seatunnel-formats/seatunnel-format-avro/src/main/java/org/apache/seatunnel/format/avro/SeaTunnelRowTypeToAvroSchemaConverter.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.seatunnel.format.avro; + +import org.apache.seatunnel.api.table.type.ArrayType; +import org.apache.seatunnel.api.table.type.BasicType; +import org.apache.seatunnel.api.table.type.DecimalType; +import org.apache.seatunnel.api.table.type.MapType; +import org.apache.seatunnel.api.table.type.SeaTunnelDataType; +import org.apache.seatunnel.api.table.type.SeaTunnelRowType; +import org.apache.seatunnel.format.avro.exception.AvroFormatErrorCode; +import org.apache.seatunnel.format.avro.exception.SeaTunnelAvroFormatException; + +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; + +import java.util.ArrayList; +import java.util.List; + +public class SeaTunnelRowTypeToAvroSchemaConverter { + + public static Schema buildAvroSchemaWithRowType(SeaTunnelRowType seaTunnelRowType) { + List fields = new ArrayList<>(); + SeaTunnelDataType[] fieldTypes = seaTunnelRowType.getFieldTypes(); + String[] fieldNames = seaTunnelRowType.getFieldNames(); + for (int i = 0; i < fieldNames.length; i++) { + fields.add(generateField(fieldNames[i], fieldTypes[i])); + } + return Schema.createRecord("SeaTunnelRecord", null, null, false, fields); + } + + private static Schema.Field generateField( + String fieldName, SeaTunnelDataType seaTunnelDataType) { + return new Schema.Field( + fieldName, + seaTunnelDataType2AvroDataType(fieldName, seaTunnelDataType), + null, + null); + } + + private static Schema seaTunnelDataType2AvroDataType( + String fieldName, SeaTunnelDataType seaTunnelDataType) { + + switch (seaTunnelDataType.getSqlType()) { + case STRING: + return Schema.create(Schema.Type.STRING); + case BYTES: + return Schema.create(Schema.Type.BYTES); + case TINYINT: + case SMALLINT: + case INT: + return Schema.create(Schema.Type.INT); + case BIGINT: + return Schema.create(Schema.Type.LONG); + case FLOAT: + return Schema.create(Schema.Type.FLOAT); + case DOUBLE: + return Schema.create(Schema.Type.DOUBLE); + case BOOLEAN: + return Schema.create(Schema.Type.BOOLEAN); + case MAP: + SeaTunnelDataType valueType = ((MapType) seaTunnelDataType).getValueType(); + return Schema.createMap(seaTunnelDataType2AvroDataType(fieldName, valueType)); + case ARRAY: + BasicType elementType = ((ArrayType) seaTunnelDataType).getElementType(); + return Schema.createArray(seaTunnelDataType2AvroDataType(fieldName, elementType)); + case ROW: + SeaTunnelDataType[] fieldTypes = + ((SeaTunnelRowType) seaTunnelDataType).getFieldTypes(); + String[] fieldNames = ((SeaTunnelRowType) seaTunnelDataType).getFieldNames(); + List subField = new ArrayList<>(); + for (int i = 0; i < fieldNames.length; i++) { + subField.add(generateField(fieldNames[i], fieldTypes[i])); + } + return Schema.createRecord(fieldName, null, null, false, subField); + case DECIMAL: + int precision = ((DecimalType) seaTunnelDataType).getPrecision(); + int scale = ((DecimalType) seaTunnelDataType).getScale(); + LogicalTypes.Decimal decimal = LogicalTypes.decimal(precision, scale); + return decimal.addToSchema(Schema.create(Schema.Type.BYTES)); + case TIMESTAMP: + return LogicalTypes.localTimestampMillis() + .addToSchema(Schema.create(Schema.Type.LONG)); + case DATE: + return LogicalTypes.date().addToSchema(Schema.create(Schema.Type.INT)); + case NULL: + return Schema.create(Schema.Type.NULL); + default: + String errorMsg = + String.format( + "SeaTunnel avro format is not supported for this data type [%s]", + seaTunnelDataType.getSqlType()); + throw new SeaTunnelAvroFormatException( + AvroFormatErrorCode.UNSUPPORTED_DATA_TYPE, errorMsg); + } + } +} diff --git a/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroConverterTest.java b/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroConverterTest.java index 66d847dcc3a..fb45a0b5377 100644 --- a/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroConverterTest.java +++ b/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroConverterTest.java @@ -166,7 +166,7 @@ public void testConverter() { RowToAvroConverter rowToAvroConverter = new RowToAvroConverter(rowType); GenericRecord record = rowToAvroConverter.convertRowToGenericRecord(seaTunnelRow); - AvroToRowConverter avroToRowConverter = new AvroToRowConverter(); + AvroToRowConverter avroToRowConverter = new AvroToRowConverter(rowType); SeaTunnelRow converterRow = avroToRowConverter.converter(record, rowType); Assertions.assertEquals(converterRow, seaTunnelRow); diff --git a/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroSerializationSchemaTest.java b/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroSerializationSchemaTest.java index 3291317bf8b..5f505e1ba6b 100644 --- a/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroSerializationSchemaTest.java +++ b/seatunnel-formats/seatunnel-format-avro/src/test/java/org/apache/seatunnel/format/avro/AvroSerializationSchemaTest.java @@ -26,8 +26,10 @@ import org.apache.seatunnel.api.table.type.SeaTunnelRow; import org.apache.seatunnel.api.table.type.SeaTunnelRowType; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.io.IOException; import java.math.BigDecimal; import java.time.LocalDate; import java.time.LocalDateTime; @@ -36,6 +38,10 @@ class AvroSerializationSchemaTest { + private LocalDate localDate = LocalDate.of(2023, 1, 1); + private BigDecimal bigDecimal = new BigDecimal("61592600349703735722.724745739637773662"); + private LocalDateTime localDateTime = LocalDateTime.of(2023, 1, 1, 6, 30, 40); + private SeaTunnelRow buildSeaTunnelRow() { SeaTunnelRow subSeaTunnelRow = new SeaTunnelRow(14); Map map = new HashMap(); @@ -43,10 +49,6 @@ private SeaTunnelRow buildSeaTunnelRow() { map.put("k2", "v2"); String[] strArray = new String[] {"l1", "l2"}; byte byteVal = 100; - LocalDate localDate = LocalDate.of(2023, 1, 1); - BigDecimal bigDecimal = new BigDecimal("61592600349703735722.724745739637773662"); - LocalDateTime localDateTime = LocalDateTime.of(2023, 1, 1, 6, 30, 40); - subSeaTunnelRow.setField(0, map); subSeaTunnelRow.setField(1, strArray); subSeaTunnelRow.setField(2, "strVal"); @@ -155,11 +157,21 @@ private SeaTunnelRowType buildSeaTunnelRowType() { } @Test - public void testSerialization() { + public void testSerialization() throws IOException { SeaTunnelRowType rowType = buildSeaTunnelRowType(); SeaTunnelRow seaTunnelRow = buildSeaTunnelRow(); AvroSerializationSchema serializationSchema = new AvroSerializationSchema(rowType); byte[] serialize = serializationSchema.serialize(seaTunnelRow); - assert serialize.length > 0; + AvroDeserializationSchema deserializationSchema = new AvroDeserializationSchema(rowType); + SeaTunnelRow deserialize = deserializationSchema.deserialize(serialize); + String[] strArray1 = (String[]) seaTunnelRow.getField(1); + String[] strArray2 = (String[]) deserialize.getField(1); + Assertions.assertArrayEquals(strArray1, strArray2); + SeaTunnelRow subRow = (SeaTunnelRow) deserialize.getField(14); + Assertions.assertEquals((double) subRow.getField(9), 123.456); + BigDecimal bigDecimal1 = (BigDecimal) subRow.getField(12); + Assertions.assertEquals(bigDecimal1.compareTo(bigDecimal), 0); + LocalDateTime localDateTime1 = (LocalDateTime) subRow.getField(13); + Assertions.assertEquals(localDateTime1.compareTo(localDateTime), 0); } }