Skip to content

Commit

Permalink
onnx: workaround pow with negative base (#2439)
Browse files Browse the repository at this point in the history
* onnx: workaround pow with negative base

rather than fully defining pow in the cpu backend (as in #2318),
this implements a much smaller change which is sufficient to evaluate silero-vad
onnx models. Specifically, checking if pow is run with 2.0 exponent, and if so
evaluate as simply `x*x` instead of the cpu backend of `e^(2.0 * ln(x))`.

* PR: use Tensor::powf insead

powf correctly handles a negative base.
  • Loading branch information
shua authored Aug 22, 2024
1 parent 6070278 commit a8288b7
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,15 @@ fn simple_eval_(
"Pow" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_pow(input1)?;
values.insert(node.output[0].clone(), output);
// HACK: current implementation of broadcast_pow cannot handle negative base,
// so we use powf where we can, which *does* correctly handle negative base.
if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::<f64>())() {
let output = input0.powf(exp as f64)?;
values.insert(node.output[0].clone(), output);
} else {
let output = input0.broadcast_pow(input1)?;
values.insert(node.output[0].clone(), output);
}
}
"Exp" => {
let xs = get(&node.input[0])?;
Expand Down

0 comments on commit a8288b7

Please sign in to comment.