diff --git a/docs/en/transform-v2/sql-functions.md b/docs/en/transform-v2/sql-functions.md index dd8b8dbfdd1..5f74efcf6af 100644 --- a/docs/en/transform-v2/sql-functions.md +++ b/docs/en/transform-v2/sql-functions.md @@ -891,3 +891,19 @@ Returns NULL if 'a' is equal to 'b', otherwise 'a'. Example: NULLIF(A, B) + +## Conditional Functions + +### CASE + +```sql +CASE value WHEN compare_value THEN result [WHEN compare_value THEN result ...] [ELSE result] END +CASE WHEN condition THEN result [WHEN condition THEN result ...] [ELSE result] END +``` + +Converts a value to another. + +Example: + +```CASE state WHEN 1 THEN 'on' ELSE 'off' END``` +```CASE WHEN state>0 THEN 1 ELSE 0 END``` diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/java/org/apache/seatunnel/e2e/transform/TestSQLIT.java b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/java/org/apache/seatunnel/e2e/transform/TestSQLIT.java index d54a2addaf5..8769053a46d 100644 --- a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/java/org/apache/seatunnel/e2e/transform/TestSQLIT.java +++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/java/org/apache/seatunnel/e2e/transform/TestSQLIT.java @@ -52,5 +52,17 @@ public void testSQLTransform(TestContainer container) throws IOException, Interr Container.ExecResult sqlAllColumns = container.executeJob("/sql_transform/sql_all_columns.conf"); Assertions.assertEquals(0, sqlAllColumns.getExitCode()); + + // region case when + Container.ExecResult caseFieldWhenCondition = + container.executeJob("/sql_transform/case_field_when_condition.conf"); + Assertions.assertEquals(0, caseFieldWhenCondition.getExitCode()); + Container.ExecResult caseFieldWhenValue = + container.executeJob("/sql_transform/case_field_when_value.conf"); + Assertions.assertEquals(0, caseFieldWhenValue.getExitCode()); + Container.ExecResult caseWhenCondition = + container.executeJob("/sql_transform/case_when_condition.conf"); + Assertions.assertEquals(0, caseWhenCondition.getExitCode()); + // endregion } } diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/case_field_when_condition.conf b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/case_field_when_condition.conf new file mode 100644 index 00000000000..1b24a1df7c8 --- /dev/null +++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/case_field_when_condition.conf @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +###### +###### This config file is a demonstration of streaming processing in seatunnel config +###### + +env { + execution.parallelism = 1 + job.mode = "BATCH" + checkpoint.interval = 10000 +} + +source { + FakeSource { + result_table_name = "fake" + schema = { + fields { + state = "int" + name = "string" + price = "double" + } + } + rows = [ + {fields = [-6, "javalover123", 134.22], kind = INSERT} + ] + } +} + +transform { + Sql { + source_table_name = "fake" + result_table_name = "fake1" + query = "select case state when state>0 then 1 else 0 end show_tag, name, price from fake" + } +} + +sink { + Console { + source_table_name = "fake1" + } + Assert { + source_table_name = "fake1" + rules = { + field_rules = [ + { + field_name = "show_tag" + field_type = "int" + field_value = [ + {equals_to = 0} + ] + }, + { + field_name = "name" + field_type = "string" + field_value = [ + {equals_to = "javalover123"} + ] + }, + { + field_name = "price" + field_type = "double" + field_value = [ + {equals_to = 134.22} + ] + } + ] + } + } +} \ No newline at end of file diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/case_field_when_value.conf b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/case_field_when_value.conf new file mode 100644 index 00000000000..d0f718c0153 --- /dev/null +++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/case_field_when_value.conf @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +###### +###### This config file is a demonstration of streaming processing in seatunnel config +###### + +env { + execution.parallelism = 1 + job.mode = "BATCH" + checkpoint.interval = 10000 +} + +source { + FakeSource { + result_table_name = "fake" + schema = { + fields { + state = "bigint" + name = "string" + price = "double" + } + } + rows = [ + {fields = [6, "javalover123", 134.22], kind = INSERT} + ] + } +} + +transform { + Sql { + source_table_name = "fake" + result_table_name = "fake1" + query = "select case state when 6 then 1 else 0 end show_tag, name, price from fake" + } +} + +sink { + Console { + source_table_name = "fake1" + } + Assert { + source_table_name = "fake1" + rules = { + field_rules = [ + { + field_name = "show_tag" + field_type = "int" + field_value = [ + {equals_to = 1} + ] + }, + { + field_name = "name" + field_type = "string" + field_value = [ + {equals_to = "javalover123"} + ] + }, + { + field_name = "price" + field_type = "double" + field_value = [ + {equals_to = 134.22} + ] + } + ] + } + } +} \ No newline at end of file diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/case_when_condition.conf b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/case_when_condition.conf new file mode 100644 index 00000000000..799ff24c9f2 --- /dev/null +++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-2/src/test/resources/sql_transform/case_when_condition.conf @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +###### +###### This config file is a demonstration of streaming processing in seatunnel config +###### + +env { + execution.parallelism = 1 + job.mode = "BATCH" + checkpoint.interval = 10000 +} + +source { + FakeSource { + result_table_name = "fake" + schema = { + fields { + state = "int" + name = "string" + price = "double" + } + } + rows = [ + {fields = [6, "javalover123", 134.22], kind = INSERT} + ] + } +} + +transform { + Sql { + source_table_name = "fake" + result_table_name = "fake1" + query = "select case when state>0 then 1 else 0 end show_tag, name, price from fake" + } +} + +sink { + Console { + source_table_name = "fake1" + } + Assert { + source_table_name = "fake1" + rules = { + field_rules = [ + { + field_name = "show_tag" + field_type = "int" + field_value = [ + {equals_to = 1} + ] + }, + { + field_name = "name" + field_type = "string" + field_value = [ + {equals_to = "javalover123"} + ] + }, + { + field_name = "price" + field_type = "double" + field_value = [ + {equals_to = 134.22} + ] + } + ] + } + } +} \ No newline at end of file diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLEngine.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLEngine.java index 55fbe04cf13..0b988af523c 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLEngine.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLEngine.java @@ -70,7 +70,7 @@ public void init(String inputTableName, SeaTunnelRowType inputRowType, String sq this.zetaSQLType = new ZetaSQLType(inputRowType, udfList); this.zetaSQLFunction = new ZetaSQLFunction(inputRowType, zetaSQLType, udfList); - this.zetaSQLFilter = new ZetaSQLFilter(zetaSQLFunction); + this.zetaSQLFilter = new ZetaSQLFilter(zetaSQLFunction, zetaSQLType); parseSQL(); } diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFilter.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFilter.java index 77b59c51bc6..52dbe90d526 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFilter.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFilter.java @@ -17,6 +17,7 @@ package org.apache.seatunnel.transform.sql.zeta; +import org.apache.seatunnel.api.table.type.BasicType; import org.apache.seatunnel.common.exception.CommonErrorCode; import org.apache.seatunnel.transform.exception.TransformException; @@ -47,9 +48,15 @@ public class ZetaSQLFilter { private final ZetaSQLFunction zetaSQLFunction; + private final ZetaSQLType zetaSQLType; - public ZetaSQLFilter(ZetaSQLFunction zetaSQLFunction) { + public ZetaSQLFilter(ZetaSQLFunction zetaSQLFunction, ZetaSQLType zetaSQLType) { this.zetaSQLFunction = zetaSQLFunction; + this.zetaSQLType = zetaSQLType; + } + + public boolean isConditionExpr(Expression expression) { + return BasicType.BOOLEAN_TYPE.equals(zetaSQLType.getExpressionType(expression)); } public boolean executeFilter(Expression whereExpr, Object[] inputFields) { @@ -185,7 +192,7 @@ private Pair executeComparisonOperator( return Pair.of(leftVal, rightVal); } - private boolean equalsToExpr(Pair pair) { + boolean equalsToExpr(Pair pair) { Object leftVal = pair.getLeft(); Object rightVal = pair.getRight(); if (leftVal == null || rightVal == null) { diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java index 4e9bf06e12c..6630e09043a 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java @@ -28,7 +28,10 @@ import org.apache.seatunnel.transform.sql.zeta.functions.StringFunction; import org.apache.seatunnel.transform.sql.zeta.functions.SystemFunction; +import org.apache.commons.lang3.tuple.Pair; + import net.sf.jsqlparser.expression.BinaryExpression; +import net.sf.jsqlparser.expression.CaseExpression; import net.sf.jsqlparser.expression.CastExpression; import net.sf.jsqlparser.expression.DoubleValue; import net.sf.jsqlparser.expression.Expression; @@ -39,6 +42,7 @@ import net.sf.jsqlparser.expression.Parenthesis; import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.TimeKeyExpression; +import net.sf.jsqlparser.expression.WhenClause; import net.sf.jsqlparser.expression.operators.arithmetic.Addition; import net.sf.jsqlparser.expression.operators.arithmetic.Concat; import net.sf.jsqlparser.expression.operators.arithmetic.Division; @@ -162,6 +166,7 @@ public class ZetaSQLFunction { public static final String NULLIF = "NULLIF"; private final SeaTunnelRowType inputRowType; + private final ZetaSQLFilter zetaSQLFilter; private final ZetaSQLType zetaSQLType; private final List udfList; @@ -170,6 +175,7 @@ public ZetaSQLFunction( SeaTunnelRowType inputRowType, ZetaSQLType zetaSQLType, List udfList) { this.inputRowType = inputRowType; this.zetaSQLType = zetaSQLType; + this.zetaSQLFilter = new ZetaSQLFilter(this, zetaSQLType); this.udfList = udfList; } @@ -220,6 +226,15 @@ public Object computeForValue(Expression expression, Object[] inputFields) { Parenthesis parenthesis = (Parenthesis) expression; return computeForValue(parenthesis.getExpression(), inputFields); } + if (zetaSQLFilter.isConditionExpr(expression)) { + return zetaSQLFilter.executeFilter(expression, inputFields); + } + if (expression instanceof CaseExpression) { + CaseExpression caseExpression = (CaseExpression) expression; + final Object value = executeCaseExpr(caseExpression, inputFields); + SeaTunnelDataType type = zetaSQLType.getExpressionType(expression); + return SystemFunction.castAs(value, type); + } if (expression instanceof BinaryExpression) { return executeBinaryExpr((BinaryExpression) expression, inputFields); } @@ -435,6 +450,23 @@ public Object executeTimeKeyExpr(String timeKeyExpr) { String.format("Unsupported TimeKey expression: %s", timeKeyExpr)); } + public Object executeCaseExpr(CaseExpression caseExpression, Object[] inputFields) { + Expression switchExpr = caseExpression.getSwitchExpression(); + Object switchValue = switchExpr == null ? null : computeForValue(switchExpr, inputFields); + for (WhenClause whenClause : caseExpression.getWhenClauses()) { + final Object when = computeForValue(whenClause.getWhenExpression(), inputFields); + // match: case [column] when column1 compare other, add by javalover123 + boolean isComparison = zetaSQLFilter.isConditionExpr(whenClause.getWhenExpression()); + if (isComparison && (boolean) when) { + return computeForValue(whenClause.getThenExpression(), inputFields); + } else if (!isComparison && zetaSQLFilter.equalsToExpr(Pair.of(switchValue, when))) { + return computeForValue(whenClause.getThenExpression(), inputFields); + } + } + final Expression elseExpression = caseExpression.getElseExpression(); + return elseExpression == null ? null : computeForValue(elseExpression, inputFields); + } + public Object executeCastExpr(CastExpression castExpression, Object arg) { String dataType = castExpression.getType().getDataType(); List args = new ArrayList<>(2); diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java index 51cbb10e6d2..31548867ae2 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java @@ -26,7 +26,10 @@ import org.apache.seatunnel.common.exception.CommonErrorCode; import org.apache.seatunnel.transform.exception.TransformException; +import org.apache.commons.collections4.CollectionUtils; + import net.sf.jsqlparser.expression.BinaryExpression; +import net.sf.jsqlparser.expression.CaseExpression; import net.sf.jsqlparser.expression.CastExpression; import net.sf.jsqlparser.expression.DoubleValue; import net.sf.jsqlparser.expression.Expression; @@ -37,12 +40,22 @@ import net.sf.jsqlparser.expression.Parenthesis; import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.TimeKeyExpression; +import net.sf.jsqlparser.expression.WhenClause; import net.sf.jsqlparser.expression.operators.arithmetic.Concat; +import net.sf.jsqlparser.expression.operators.conditional.AndExpression; +import net.sf.jsqlparser.expression.operators.conditional.OrExpression; +import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.ExpressionList; +import net.sf.jsqlparser.expression.operators.relational.InExpression; +import net.sf.jsqlparser.expression.operators.relational.IsNullExpression; +import net.sf.jsqlparser.expression.operators.relational.LikeExpression; import net.sf.jsqlparser.schema.Column; import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; import java.util.List; +import java.util.stream.Collectors; public class ZetaSQLType { public static final String DECIMAL = "DECIMAL"; @@ -106,6 +119,17 @@ public SeaTunnelDataType getExpressionType(Expression expression) { if (expression instanceof Concat) { return BasicType.STRING_TYPE; } + if (expression instanceof CaseExpression) { + return getCaseType((CaseExpression) expression); + } + if (expression instanceof ComparisonOperator + || expression instanceof IsNullExpression + || expression instanceof InExpression + || expression instanceof LikeExpression + || expression instanceof AndExpression + || expression instanceof OrExpression) { + return BasicType.BOOLEAN_TYPE; + } if (expression instanceof CastExpression) { return getCastType((CastExpression) expression); } @@ -149,6 +173,73 @@ public SeaTunnelDataType getExpressionType(Expression expression) { String.format("Unsupported SQL Expression: %s ", expression.toString())); } + public boolean isNumberType(SqlType type) { + return type.compareTo(SqlType.TINYINT) >= 0 && type.compareTo(SqlType.DECIMAL) <= 0; + } + + public SeaTunnelDataType getMaxType( + SeaTunnelDataType leftType, SeaTunnelDataType rightType) { + if (leftType == null || BasicType.VOID_TYPE.equals(leftType)) { + return rightType; + } + if (rightType == null || BasicType.VOID_TYPE.equals(rightType)) { + return leftType; + } + if (leftType.equals(rightType)) { + return leftType; + } + + final boolean isAllNumber = + isNumberType(leftType.getSqlType()) && isNumberType(rightType.getSqlType()); + if (!isAllNumber) { + throw new TransformException( + CommonErrorCode.UNSUPPORTED_OPERATION, + leftType + " type not compatible " + rightType); + } + + if (leftType.getSqlType() == SqlType.DECIMAL || rightType.getSqlType() == SqlType.DECIMAL) { + int precision = 0; + int scale = 0; + if (leftType.getSqlType() == SqlType.DECIMAL) { + DecimalType decimalType = (DecimalType) leftType; + precision = decimalType.getPrecision(); + scale = decimalType.getScale(); + } + if (rightType.getSqlType() == SqlType.DECIMAL) { + DecimalType decimalType = (DecimalType) rightType; + precision = Math.max(decimalType.getPrecision(), precision); + scale = Math.max(decimalType.getScale(), scale); + } + return new DecimalType(precision, scale); + } + return leftType.getSqlType().compareTo(rightType.getSqlType()) <= 0 ? rightType : leftType; + } + + public SeaTunnelDataType getMaxType(Collection> types) { + if (CollectionUtils.isEmpty(types)) { + throw new TransformException( + CommonErrorCode.UNSUPPORTED_OPERATION, "getMaxType parameter is null"); + } + Iterator> iterator = types.iterator(); + SeaTunnelDataType result = iterator.next(); + while (iterator.hasNext()) { + result = getMaxType(result, iterator.next()); + } + return result; + } + + private SeaTunnelDataType getCaseType(CaseExpression caseExpression) { + final Collection> types = + caseExpression.getWhenClauses().stream() + .map(WhenClause::getThenExpression) + .map(this::getExpressionType) + .collect(Collectors.toSet()); + if (caseExpression.getElseExpression() != null) { + types.add(getExpressionType(caseExpression.getElseExpression())); + } + return getMaxType(types); + } + private SeaTunnelDataType getCastType(CastExpression castExpression) { String dataType = castExpression.getType().getDataType(); switch (dataType.toUpperCase()) { diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/SystemFunction.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/SystemFunction.java index fbd53b655e3..59275af98d8 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/SystemFunction.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/SystemFunction.java @@ -17,6 +17,8 @@ package org.apache.seatunnel.transform.sql.zeta.functions; +import org.apache.seatunnel.api.table.type.DecimalType; +import org.apache.seatunnel.api.table.type.SeaTunnelDataType; import org.apache.seatunnel.common.exception.CommonErrorCode; import org.apache.seatunnel.transform.exception.TransformException; @@ -25,6 +27,7 @@ import java.time.LocalDate; import java.time.LocalDateTime; import java.time.LocalTime; +import java.util.ArrayList; import java.util.List; public class SystemFunction { @@ -60,6 +63,18 @@ public static Object nullif(List args) { return v1; } + public static Object castAs(Object arg, SeaTunnelDataType type) { + final ArrayList args = new ArrayList<>(4); + args.add(arg); + args.add(type.getSqlType().toString()); + if (DecimalType.class.equals(type.getClass())) { + final DecimalType decimalType = (DecimalType) type; + args.add(decimalType.getPrecision()); + args.add(decimalType.getScale()); + } + return castAs(args); + } + public static Object castAs(List args) { Object v1 = args.get(0); String v2 = (String) args.get(1);