Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Jan 28, 2025
1 parent 684afed commit 803971e
Showing 1 changed file with 23 additions and 24 deletions.
47 changes: 23 additions & 24 deletions spark/src/main/scala/org/apache/comet/serde/hash.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,10 @@ object CometXxHash64 extends CometExpressionSerde {
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val hash = expr.asInstanceOf[XxHash64]
for (child <- hash.children) {
child.dataType match {
case dt: DecimalType if dt.precision > 18 =>
// Spark converts decimals with precision > 18 into
// Java BigDecimal before hashing
withInfo(expr, s"Unsupported datatype: $dt (precision > 18)")
return None
case dt if !supportedDataType(dt) =>
withInfo(expr, s"Unsupported datatype $dt")
return None
case _ =>
}
if (!HashUtils.isSupportedType(expr)) {
return None
}
val hash = expr.asInstanceOf[XxHash64]
val exprs = hash.children.map(exprToProtoInternal(_, inputs, binding))
val seedBuilder = ExprOuterClass.Literal
.newBuilder()
object Murmur3Hash extends CometExpressionSerde {

  /**
   * Converts a Spark `Murmur3Hash` expression into a Comet protobuf scalar
   * function call (`murmur3_hash`).
   *
   * @param expr the Spark expression; must be a `Murmur3Hash`
   * @param inputs attributes used to resolve bound references in the children
   * @param binding whether child expressions should be bound to `inputs`
   * @return the serialized expression, or `None` if any child has a data type
   *         that Comet cannot hash natively (reason recorded via `withInfo`
   *         inside `HashUtils.isSupportedType`)
   */
  override def convert(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    if (!HashUtils.isSupportedType(expr)) {
      return None
    }
    val hash = expr.asInstanceOf[Murmur3Hash]
    val exprs = hash.children.map(exprToProtoInternal(_, inputs, binding))
    // The seed is a 32-bit value for Murmur3, serialized as an IntegerType literal.
    val seedBuilder = ExprOuterClass.Literal
      .newBuilder()
      .setDatatype(serializeDataType(IntegerType).get)
      .setIntVal(hash.seed)
    val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
    // the seed is put at the end of the arguments
    scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*)
  }
}

/**
 * Type-support checks shared by the hash expression serdes
 * (`CometXxHash64` and `Murmur3Hash`).
 */
private object HashUtils {

  /**
   * Returns true if every child of `expr` has a data type that Comet can hash
   * natively. On the first unsupported child, records the reason against
   * `expr` via `withInfo` and returns false.
   *
   * Note: uses `forall` (which short-circuits on the first `false`) instead of
   * `return` inside a `for` loop — a `return` inside a Scala `for` body is a
   * non-local return implemented by throwing `NonLocalReturnControl`.
   *
   * @param expr the hash expression whose children are checked
   * @return true if all child data types are supported
   */
  def isSupportedType(expr: Expression): Boolean = {
    expr.children.forall { child =>
      child.dataType match {
        case dt: DecimalType if dt.precision > 18 =>
          // Spark converts decimals with precision > 18 into
          // Java BigDecimal before hashing
          withInfo(expr, s"Unsupported datatype: $dt (precision > 18)")
          false
        case dt if !supportedDataType(dt) =>
          withInfo(expr, s"Unsupported datatype $dt")
          false
        case _ =>
          true
      }
    }
  }
}

0 comments on commit 803971e

Please sign in to comment.