Skip to content

Commit

Permalink
Fix an issue with unsorted deleteat
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Sep 3, 2024
1 parent 2da3b04 commit c3b4fae
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions ext/TensorOperationscuTENSORExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,19 +278,19 @@ function plan_trace(@nospecialize(A::AbstractArray), Ainds::ModeType,
# TODO: check if this can be avoided, available in caller
# TODO: cuTENSOR will allocate sizes and strides anyways, could use that here
p, q = TO.trace_indices(tuple(Ainds...), tuple(Cinds...))

qsorted = TT.sort(q[2])
# add strides of cindA2 to strides of cindA1 -> selects diagonal
stA = strides(A)
for (i, j) in zip(q...)
stA = Base.setindex(stA, stA[i] + stA[j], i)
end
szA = TT.deleteat(size(A), q[2])
stA′ = TT.deleteat(stA, q[2])
szA = TT.deleteat(size(A), qsorted)
stA′ = TT.deleteat(stA, qsorted)

descA = CuTensorDescriptor(A; size=szA, strides=stA′)
descC = CuTensorDescriptor(C)

modeA = collect(Cint, deleteat!(Ainds, q[2]))
modeA = collect(Cint, deleteat!(Ainds, qsorted))
modeC = collect(Cint, Cinds)

actual_compute_type = if compute_type === nothing
Expand Down

0 comments on commit c3b4fae

Please sign in to comment.