Skip to content

Commit

Permalink
[Improve][SQL-Transform] support struct query
Browse files Browse the repository at this point in the history
  • Loading branch information
liunaijie committed Mar 11, 2024
1 parent 7f051b2 commit 175defe
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
package org.apache.seatunnel.transform.sql.zeta;

import org.apache.seatunnel.api.table.type.DecimalType;
import org.apache.seatunnel.api.table.type.MapType;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.common.exception.CommonErrorCodeDeprecated;
Expand Down Expand Up @@ -56,6 +58,7 @@
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class ZetaSQLFunction {
// ============================internal functions=====================
Expand Down Expand Up @@ -199,8 +202,29 @@ public Object computeForValue(Expression expression, Object[] inputFields) {
return ((StringValue) expression).getValue();
}
if (expression instanceof Column) {
int idx = inputRowType.indexOf(((Column) expression).getColumnName());
return inputFields[idx];
Column columnExp = (Column) expression;
try {
String columnName = columnExp.getColumnName();
int idx = inputRowType.indexOf(columnName);
return inputFields[idx];
} catch (IllegalArgumentException e) {
String fullyQualifiedName = columnExp.getFullyQualifiedName();
String[] columnNames = fullyQualifiedName.split("\\.");
int deep = columnNames.length;
SeaTunnelDataType parDataType = inputRowType;
SeaTunnelRow parRowValues = new SeaTunnelRow(inputFields);
Object res = parRowValues;
for (int i = 0; i < deep; i++) {
if (parDataType instanceof MapType) {
return ((Map) res).get(columnNames[i]);
}
parRowValues = (SeaTunnelRow) res;
int idx = ((SeaTunnelRowType) parDataType).indexOf(columnNames[i]);
parDataType = ((SeaTunnelRowType) parDataType).getFieldType(idx);
res = parRowValues.getFields()[idx];
}
return res;
}
}
if (expression instanceof Function) {
Function function = (Function) expression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.seatunnel.api.table.type.BasicType;
import org.apache.seatunnel.api.table.type.DecimalType;
import org.apache.seatunnel.api.table.type.LocalTimeType;
import org.apache.seatunnel.api.table.type.MapType;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
Expand Down Expand Up @@ -101,8 +102,32 @@ public SeaTunnelDataType<?> getExpressionType(Expression expression) {
return BasicType.STRING_TYPE;
}
if (expression instanceof Column) {
String columnName = ((Column) expression).getColumnName();
return inputRowType.getFieldType(inputRowType.indexOf(columnName));
Column columnExp = (Column) expression;
try {
String columnName = columnExp.getColumnName();
return inputRowType.getFieldType(inputRowType.indexOf(columnName));
} catch (IllegalArgumentException e) {
// fullback logical to handel a.b.c query.
String fullyQualifiedName = columnExp.getFullyQualifiedName();
String[] columnNames = fullyQualifiedName.split("\\.");
int deep = columnNames.length;
SeaTunnelRowType parRowType = inputRowType;
SeaTunnelDataType<?> filedTypeRes = null;
for (int i = 0; i < deep; i++) {
filedTypeRes = parRowType.getFieldType(parRowType.indexOf(columnNames[i]));
if (filedTypeRes instanceof SeaTunnelRowType) {
parRowType = (SeaTunnelRowType) filedTypeRes;
} else if (filedTypeRes instanceof MapType) {
// for map type. only support it's the latest struct.
if (i != deep - 1) {
throw new IllegalArgumentException(
"For now, we only support map struct is the latest struct in inner query function! Please modify your query!");
}
return ((MapType<?, ?>) filedTypeRes).getValueType();
}
}
return filedTypeRes;
}
}
if (expression instanceof Function) {
return getFunctionType((Function) expression);
Expand Down

0 comments on commit 175defe

Please sign in to comment.