Skip to content

Commit

Permalink
Make sure that ȳ's output is a compatible float type
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed May 30, 2019
1 parent aa4e64a commit 9c2c05e
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/back.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corr
"""
function jacobian(f, x::AbstractVector)
y::AbstractVector, back = forward(f, x)
z = float(zero(eltype(data(y))))
# Using broadcasting so that output of `ȳ` is a GPU array if `y` is so:
(i) = ((j, _) -> i == j).(1:length(y), y)
(i) = ((j, _) -> i == j).(1:length(y), y) .+ z
vcat([transpose(back((i))[1]) for i = 1:length(y)]...)
end

Expand Down

0 comments on commit 9c2c05e

Please sign in to comment.