diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 24da23a20f..d6beb70e7b 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -169,8 +169,22 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
                     }
-                    Op::ScatterAdd(..) => Err(Error::BackwardNotSupported { op: "scatter-add" })?,
-                    Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
+                    Op::ScatterAdd(init, indexes, src, dim) => {
+                        let init_sum_grad = grads.or_insert(init)?;
+                        *init_sum_grad = init_sum_grad.add(&grad)?;
+
+                        let src_grad = grad.gather(indexes, *dim)?;
+                        let src_sum_grad = grads.or_insert(src)?;
+                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
+                    }
+                    Op::IndexAdd(init, indexes, src, dim) => {
+                        let init_sum_grad = grads.or_insert(init)?;
+                        *init_sum_grad = init_sum_grad.add(&grad)?;
+
+                        let src_grad = grad.index_select(indexes, *dim)?;
+                        let src_sum_grad = grads.or_insert(src)?;
+                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
+                    }
                     Op::IndexSelect(arg, indexes, dim) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
@@ -228,7 +242,7 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&grad)?;
                     }
-                    Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
+                    Op::Cmp(_args, _) => {}
                     Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
                         let node = broadcast_back(arg, node, reduced_dims)?;
                         let grad = broadcast_back(arg, &grad, reduced_dims)?;
@@ -268,7 +282,12 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
                     }
-                    Op::Unary(_, UnaryOp::Abs) => Err(Error::BackwardNotSupported { op: "abs" })?,
+                    Op::Unary(arg, UnaryOp::Abs) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        let ones = arg.ones_like()?;
+                        let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
+                        *sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
+                    }
                     Op::Unary(arg, UnaryOp::Exp) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&(&grad * *node)?)?
@@ -303,12 +322,8 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
-                    Op::Reduce(_, ReduceOp::ArgMin, _) => {
-                        Err(Error::BackwardNotSupported { op: "argmin" })?
-                    }
-                    Op::Reduce(_, ReduceOp::ArgMax, _) => {
-                        Err(Error::BackwardNotSupported { op: "argmax" })?
-                    }
+                    Op::Reduce(_, ReduceOp::ArgMin, _) => {}
+                    Op::Reduce(_, ReduceOp::ArgMax, _) => {}
                     Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
                     Op::Reshape(arg) => {
                         let arg_grad = grad.reshape(arg.dims())?;
@@ -316,7 +331,11 @@ impl Tensor {
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
                     Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
-                    Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
+                    Op::Unary(arg, UnaryOp::Relu) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
+                        *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
+                    }
                     Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
                     Op::CustomOp1(arg, c) => {
                         if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
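
As a rough illustration of what the new backward passes make possible, a minimal gradient check for `relu` and `abs` could look like the sketch below. It assumes candle's public `Var`/`backward` API; the tensor values and expected gradients are illustrative and not taken from this PR's tests.

```rust
use candle_core::{Device, Result, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Var::new(&[-2f32, -1., 0.5, 3.], &dev)?;

    // relu backward: gradient is 1 where x >= 0 and 0 elsewhere.
    let y = x.relu()?.sum_all()?;
    let grads = y.backward()?;
    let grad_x = grads.get(&x).expect("no gradient for x");
    assert_eq!(grad_x.to_vec1::<f32>()?, [0., 0., 1., 1.]);

    // abs backward: gradient is +1 where x >= 0 and -1 elsewhere.
    let y = x.abs()?.sum_all()?;
    let grads = y.backward()?;
    let grad_x = grads.get(&x).expect("no gradient for x");
    assert_eq!(grad_x.to_vec1::<f32>()?, [-1., -1., 1., 1.]);
    Ok(())
}
```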