Skip to content

Commit

Permalink
Support backprop for a few more ops. (#254)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Jul 26, 2023
1 parent 4f92420 commit 89ba005
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions candle-core/src/backprop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,22 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
}
Op::ScatterAdd(..) => Err(Error::BackwardNotSupported { op: "scatter-add" })?,
Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
Op::ScatterAdd(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
*init_sum_grad = init_sum_grad.add(&grad)?;

let src_grad = grad.gather(indexes, *dim)?;
let src_sum_grad = grads.or_insert(src)?;
*src_sum_grad = src_sum_grad.add(&src_grad)?;
}
Op::IndexAdd(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
*init_sum_grad = init_sum_grad.add(&grad)?;

let src_grad = grad.index_select(indexes, *dim)?;
let src_sum_grad = grads.or_insert(src)?;
*src_sum_grad = src_sum_grad.add(&src_grad)?;
}
Op::IndexSelect(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
Expand Down Expand Up @@ -228,7 +242,7 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad)?;
}
Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
Op::Cmp(_args, _) => {}
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
let node = broadcast_back(arg, node, reduced_dims)?;
let grad = broadcast_back(arg, &grad, reduced_dims)?;
Expand Down Expand Up @@ -268,7 +282,12 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
}
Op::Unary(_, UnaryOp::Abs) => Err(Error::BackwardNotSupported { op: "abs" })?,
Op::Unary(arg, UnaryOp::Abs) => {
let sum_grad = grads.or_insert(arg)?;
let ones = arg.ones_like()?;
let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
*sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
}
Op::Unary(arg, UnaryOp::Exp) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad * *node)?)?
Expand Down Expand Up @@ -303,20 +322,20 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Reduce(_, ReduceOp::ArgMin, _) => {
Err(Error::BackwardNotSupported { op: "argmin" })?
}
Op::Reduce(_, ReduceOp::ArgMax, _) => {
Err(Error::BackwardNotSupported { op: "argmax" })?
}
Op::Reduce(_, ReduceOp::ArgMin, _) => {}
Op::Reduce(_, ReduceOp::ArgMax, _) => {}
Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?;
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
}
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::CustomOp1(arg, c) => {
if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
Expand Down

0 comments on commit 89ba005

Please sign in to comment.