diff --git a/spark/src/main/scala/org/apache/comet/serde/hash.scala b/spark/src/main/scala/org/apache/comet/serde/hash.scala index fa40448dd..e695f8b1b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/hash.scala +++ b/spark/src/main/scala/org/apache/comet/serde/hash.scala @@ -30,20 +30,10 @@ object CometXxHash64 extends CometExpressionSerde { expr: Expression, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - val hash = expr.asInstanceOf[XxHash64] - for (child <- hash.children) { - child.dataType match { - case dt: DecimalType if dt.precision > 18 => - // Spark converts decimals with precision > 18 into - // Java BigDecimal before hashing - withInfo(expr, s"Unsupported datatype: $dt (precision > 18)") - return None - case dt if !supportedDataType(dt) => - withInfo(expr, s"Unsupported datatype $dt") - return None - case _ => - } + if (!HashUtils.isSupportedType(expr)) { + return None } + val hash = expr.asInstanceOf[XxHash64] val exprs = hash.children.map(exprToProtoInternal(_, inputs, binding)) val seedBuilder = ExprOuterClass.Literal .newBuilder() @@ -60,27 +50,36 @@ object Murmur3Hash extends CometExpressionSerde { expr: Expression, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { + if (!HashUtils.isSupportedType(expr)) { + return None + } val hash = expr.asInstanceOf[Murmur3Hash] - for (child <- hash.children) { + val exprs = hash.children.map(exprToProtoInternal(_, inputs, binding)) + val seedBuilder = ExprOuterClass.Literal + .newBuilder() + .setDatatype(serializeDataType(IntegerType).get) + .setIntVal(hash.seed) + val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build()) + // the seed is put at the end of the arguments + scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*) + } +} + +private object HashUtils { + def isSupportedType(expr: Expression): Boolean = { + for (child <- expr.children) { child.dataType match { case dt: DecimalType if dt.precision > 18 => // Spark converts decimals with precision > 18 into // Java BigDecimal before hashing withInfo(expr, s"Unsupported datatype: $dt (precision > 18)") - return None + return false case dt if !supportedDataType(dt) => withInfo(expr, s"Unsupported datatype $dt") - return None + return false case _ => } } - val exprs = hash.children.map(exprToProtoInternal(_, inputs, binding)) - val seedBuilder = ExprOuterClass.Literal - .newBuilder() - .setDatatype(serializeDataType(IntegerType).get) - .setIntVal(hash.seed) - val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build()) - // the seed is put at the end of the arguments - scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*) + true } }