Skip to content

Commit

Permalink
[Feature][avro-format] improve deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
liunaijie committed Nov 6, 2023
1 parent 44ddb62 commit 4c05290
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 124 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.seatunnel.e2e.common.TestResource;
import org.apache.seatunnel.e2e.common.TestSuiteBase;
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.container.TestContainerId;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;
import org.apache.seatunnel.format.text.TextSerializationSchema;

Expand Down Expand Up @@ -293,6 +294,7 @@ public void testSourceKafkaStartConfig(TestContainer container)
}

@TestTemplate
@DisabledOnContainer(TestContainerId.SPARK_2_4)
public void testFakeSourceToKafkaAvroFormat(TestContainer container)
throws IOException, InterruptedException {
Container.ExecResult execResult =
Expand All @@ -301,6 +303,7 @@ public void testFakeSourceToKafkaAvroFormat(TestContainer container)
}

@TestTemplate
@DisabledOnContainer(TestContainerId.SPARK_2_4)
public void testKafkaAvroToConsole(TestContainer container)
throws IOException, InterruptedException {
DefaultSeaTunnelRowSerializer serializer =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class AvroDeserializationSchema implements DeserializationSchema<SeaTunne

public AvroDeserializationSchema(SeaTunnelRowType rowType) {
this.rowType = rowType;
this.converter = new AvroToRowConverter();
this.converter = new AvroToRowConverter(rowType);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ public byte[] serialize(SeaTunnelRow element) {
return out.toByteArray();
} catch (IOException e) {
throw new SeaTunnelAvroFormatException(
AvroFormatErrorCode.SERIALIZATION_ERROR, e.toString());
AvroFormatErrorCode.SERIALIZATION_ERROR,
"Serialization error on record : " + element);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,25 @@
import org.apache.seatunnel.format.avro.exception.SeaTunnelAvroFormatException;

import org.apache.avro.Conversions;
import org.apache.avro.LogicalTypes;
import org.apache.avro.Schema;
import org.apache.avro.data.TimeConversions;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.DatumReader;

import java.io.Serializable;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.List;

public class AvroToRowConverter implements Serializable {

private static final long serialVersionUID = 8177020083886379563L;

private DatumReader<GenericRecord> reader = null;
private Schema schema;

public AvroToRowConverter() {}
public AvroToRowConverter(SeaTunnelRowType rowType) {
schema = SeaTunnelRowTypeToAvroSchemaConverter.buildAvroSchemaWithRowType(rowType);
}

public DatumReader<GenericRecord> getReader() {
if (reader == null) {
Expand All @@ -59,12 +54,12 @@ public DatumReader<GenericRecord> getReader() {
}

private DatumReader<GenericRecord> createReader() {
GenericDatumReader<GenericRecord> datumReader = new GenericDatumReader<>();
GenericDatumReader<GenericRecord> datumReader = new GenericDatumReader<>(schema, schema);
datumReader.getData().addLogicalTypeConversion(new Conversions.DecimalConversion());
datumReader.getData().addLogicalTypeConversion(new TimeConversions.DateConversion());
datumReader
.getData()
.addLogicalTypeConversion(new TimeConversions.TimestampMillisConversion());
.addLogicalTypeConversion(new TimeConversions.LocalTimestampMillisConversion());
return datumReader;
}

Expand Down Expand Up @@ -98,6 +93,9 @@ private Object convertField(SeaTunnelDataType<?> dataType, Schema.Field field, O
case DOUBLE:
case NULL:
case BYTES:
case DATE:
case DECIMAL:
case TIMESTAMP:
return val;
case TINYINT:
Class<?> typeClass = dataType.getTypeClass();
Expand All @@ -110,17 +108,6 @@ private Object convertField(SeaTunnelDataType<?> dataType, Schema.Field field, O
BasicType<?> basicType = ((ArrayType<?, ?>) dataType).getElementType();
List<Object> list = (List<Object>) val;
return convertArray(list, basicType);
case DECIMAL:
LogicalTypes.Decimal decimal =
(LogicalTypes.Decimal) field.schema().getLogicalType();
ByteBuffer buffer = (ByteBuffer) val;
byte[] bytes = buffer.array();
return new BigDecimal(new BigInteger(bytes), decimal.getScale());
case DATE:
return LocalDate.ofEpochDay((Long) val);
case TIMESTAMP:
return LocalDateTime.ofInstant(
Instant.ofEpochMilli((Long) val), ZoneId.systemDefault());
case ROW:
SeaTunnelRowType subRow = (SeaTunnelRowType) dataType;
return converter((GenericRecord) val, subRow);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,13 @@

import org.apache.seatunnel.api.table.type.ArrayType;
import org.apache.seatunnel.api.table.type.BasicType;
import org.apache.seatunnel.api.table.type.DecimalType;
import org.apache.seatunnel.api.table.type.MapType;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.format.avro.exception.AvroFormatErrorCode;
import org.apache.seatunnel.format.avro.exception.SeaTunnelAvroFormatException;

import org.apache.avro.Conversions;
import org.apache.avro.LogicalTypes;
import org.apache.avro.Schema;
import org.apache.avro.data.TimeConversions;
import org.apache.avro.generic.GenericDatumWriter;
Expand All @@ -38,11 +35,7 @@
import org.apache.avro.io.DatumWriter;

import java.io.Serializable;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.List;

Expand All @@ -55,7 +48,7 @@ public class RowToAvroConverter implements Serializable {
private final DatumWriter<GenericRecord> writer;

public RowToAvroConverter(SeaTunnelRowType rowType) {
this.schema = buildAvroSchemaWithRowType(rowType);
this.schema = SeaTunnelRowTypeToAvroSchemaConverter.buildAvroSchemaWithRowType(rowType);
this.rowType = rowType;
this.writer = createWriter();
}
Expand All @@ -66,7 +59,7 @@ private DatumWriter<GenericRecord> createWriter() {
datumWriter.getData().addLogicalTypeConversion(new TimeConversions.DateConversion());
datumWriter
.getData()
.addLogicalTypeConversion(new TimeConversions.TimestampMillisConversion());
.addLogicalTypeConversion(new TimeConversions.LocalTimestampMillisConversion());
return datumWriter;
}

Expand All @@ -89,80 +82,6 @@ public GenericRecord convertRowToGenericRecord(SeaTunnelRow element) {
return builder.build();
}

private Schema buildAvroSchemaWithRowType(SeaTunnelRowType seaTunnelRowType) {
List<Schema.Field> fields = new ArrayList<>();
SeaTunnelDataType<?>[] fieldTypes = seaTunnelRowType.getFieldTypes();
String[] fieldNames = seaTunnelRowType.getFieldNames();
for (int i = 0; i < fieldNames.length; i++) {
fields.add(generateField(fieldNames[i], fieldTypes[i]));
}
return Schema.createRecord("SeaTunnelRecord", null, null, false, fields);
}

private Schema.Field generateField(String fieldName, SeaTunnelDataType<?> seaTunnelDataType) {
return new Schema.Field(
fieldName,
seaTunnelDataType2AvroDataType(fieldName, seaTunnelDataType),
null,
null);
}

private Schema seaTunnelDataType2AvroDataType(
String fieldName, SeaTunnelDataType<?> seaTunnelDataType) {

switch (seaTunnelDataType.getSqlType()) {
case STRING:
return Schema.create(Schema.Type.STRING);
case BYTES:
return Schema.create(Schema.Type.BYTES);
case TINYINT:
case SMALLINT:
case INT:
return Schema.create(Schema.Type.INT);
case BIGINT:
return Schema.create(Schema.Type.LONG);
case FLOAT:
return Schema.create(Schema.Type.FLOAT);
case DOUBLE:
return Schema.create(Schema.Type.DOUBLE);
case BOOLEAN:
return Schema.create(Schema.Type.BOOLEAN);
case MAP:
SeaTunnelDataType<?> valueType = ((MapType<?, ?>) seaTunnelDataType).getValueType();
return Schema.createMap(seaTunnelDataType2AvroDataType(fieldName, valueType));
case ARRAY:
BasicType<?> elementType = ((ArrayType<?, ?>) seaTunnelDataType).getElementType();
return Schema.createArray(seaTunnelDataType2AvroDataType(fieldName, elementType));
case ROW:
SeaTunnelDataType<?>[] fieldTypes =
((SeaTunnelRowType) seaTunnelDataType).getFieldTypes();
String[] fieldNames = ((SeaTunnelRowType) seaTunnelDataType).getFieldNames();
List<Schema.Field> subField = new ArrayList<>();
for (int i = 0; i < fieldNames.length; i++) {
subField.add(generateField(fieldNames[i], fieldTypes[i]));
}
return Schema.createRecord(fieldName, null, null, false, subField);
case DECIMAL:
int precision = ((DecimalType) seaTunnelDataType).getPrecision();
int scale = ((DecimalType) seaTunnelDataType).getScale();
LogicalTypes.Decimal decimal = LogicalTypes.decimal(precision, scale);
return decimal.addToSchema(Schema.create(Schema.Type.BYTES));
case TIMESTAMP:
return LogicalTypes.timestampMillis().addToSchema(Schema.create(Schema.Type.LONG));
case DATE:
return LogicalTypes.date().addToSchema(Schema.create(Schema.Type.INT));
case NULL:
return Schema.create(Schema.Type.NULL);
default:
String errorMsg =
String.format(
"SeaTunnel avro format is not supported for this data type [%s]",
seaTunnelDataType.getSqlType());
throw new SeaTunnelAvroFormatException(
AvroFormatErrorCode.UNSUPPORTED_DATA_TYPE, errorMsg);
}
}

private Object resolveObject(Object data, SeaTunnelDataType<?> seaTunnelDataType) {
if (data == null) {
return null;
Expand All @@ -176,6 +95,9 @@ private Object resolveObject(Object data, SeaTunnelDataType<?> seaTunnelDataType
case DOUBLE:
case BOOLEAN:
case MAP:
case DECIMAL:
case DATE:
case TIMESTAMP:
return data;
case TINYINT:
Class<?> typeClass = seaTunnelDataType.getTypeClass();
Expand All @@ -186,12 +108,6 @@ private Object resolveObject(Object data, SeaTunnelDataType<?> seaTunnelDataType
}
}
return data;
case DECIMAL:
BigDecimal decimal = (BigDecimal) data;
return ByteBuffer.wrap(decimal.unscaledValue().toByteArray());
case DATE:
LocalDate localDate = (LocalDate) data;
return localDate.toEpochDay();
case BYTES:
return ByteBuffer.wrap((byte[]) data);
case ARRAY:
Expand All @@ -211,17 +127,15 @@ private Object resolveObject(Object data, SeaTunnelDataType<?> seaTunnelDataType
((SeaTunnelRowType) seaTunnelDataType).getFieldTypes();
String[] fieldNames = ((SeaTunnelRowType) seaTunnelDataType).getFieldNames();
Schema recordSchema =
buildAvroSchemaWithRowType((SeaTunnelRowType) seaTunnelDataType);
SeaTunnelRowTypeToAvroSchemaConverter.buildAvroSchemaWithRowType(
(SeaTunnelRowType) seaTunnelDataType);
GenericRecordBuilder recordBuilder = new GenericRecordBuilder(recordSchema);
for (int i = 0; i < fieldNames.length; i++) {
recordBuilder.set(
fieldNames[i].toLowerCase(),
resolveObject(seaTunnelRow.getField(i), fieldTypes[i]));
}
return recordBuilder.build();
case TIMESTAMP:
LocalDateTime dateTime = (LocalDateTime) data;
return (dateTime).atZone(ZoneId.systemDefault()).toInstant().toEpochMilli();
default:
String errorMsg =
String.format(
Expand Down
Loading

0 comments on commit 4c05290

Please sign in to comment.