diff --git a/src/back.jl b/src/back.jl index 329e5fcc..db9669f5 100644 --- a/src/back.jl +++ b/src/back.jl @@ -173,11 +173,8 @@ Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corr """ function jacobian(f, x::AbstractVector) y::AbstractVector, back = forward(f, x) - function ȳ(i) - δ = fill!(similar(y, Bool), false) - δ[i] = true - return δ - end + # Using broadcasting so that output of `ȳ` is a GPU array if `y` is so: + ȳ(i) = ((j, _) -> i == j).(1:length(y), y) vcat([transpose(back(ȳ(i))[1]) for i = 1:length(y)]...) end