Skip to content

Commit

Permalink
fix extension issues on julia 1.11.1
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Oct 18, 2024
1 parent 3433bc9 commit 42e9c9c
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 74 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorOperations"
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
authors = ["Lukas Devos <[email protected]>", "Maarten Van Damme <[email protected]>", "Jutho Haegeman <[email protected]>"]
version = "5.0.2"
version = "5.1.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
12 changes: 4 additions & 8 deletions ext/TensorOperationsBumperExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,11 @@ function TensorOperations.tensoralloc(::Type{A}, structure, ::Val{istemp},
end
end

function TensorOperations.blas_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β,
allocator::Union{SlabBuffer,AllocBuffer})
function TensorOperations.blas_contract!(C, A, pA, B, pB, pAB, α, β,
backend, allocator::Union{SlabBuffer,AllocBuffer})
@no_escape allocator begin
C = Base.@invoke TensorOperations.blas_contract!(C::Any, A::Any, pA::Any,
conjA::Any,
B::Any, pB::Any,
conjB::Any, pAB::Any, α::Any,
β::Any,
allocator::Any)
C = Base.@invoke TensorOperations.blas_contract!(C, A, pA, B, pB, pAB, α, β,
backend, allocator::Any)
end
return C
end
Expand Down
139 changes: 82 additions & 57 deletions ext/TensorOperationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,27 +46,34 @@ end

# The current `rrule` design makes sure that the implementation for custom types does
# not need to support the backend or allocator arguments
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
# C,
# A, pA::Index2Tuple, conjA::Bool,
# α::Number, β::Number,
# backend, allocator)
# val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend, allocator))
# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
# end
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
# C,
# A, pA::Index2Tuple, conjA::Bool,
# α::Number, β::Number,
# backend)
# val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend,))
# return val, ΔC -> (pb(ΔC)..., NoTangent())
# end
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
# C,
# A, pA::Index2Tuple, conjA::Bool,
# α::Number, β::Number)
# return _rrule_tensoradd!(C, A, pA, conjA, α, β, ())
# end
function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
C,
A, pA::Index2Tuple, conjA::Bool,
α::Number, β::Number,
backend, allocator)
val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend, allocator))
return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
end
function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
C,
A, pA::Index2Tuple, conjA::Bool,
α::Number, β::Number,
backend)
val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend,))
return val, ΔC -> (pb(ΔC)..., NoTangent())
end
function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
C,
A, pA::Index2Tuple, conjA::Bool,
α::Number, β::Number)
return _rrule_tensoradd!(C, A, pA, conjA, α, β, ())
ba...)
return _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
end
function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
C′ = tensoradd!(copy(C), A, pA, conjA, α, β, ba...)
Expand Down Expand Up @@ -98,40 +105,50 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
((), ()), One(), ba...))
return projectβ(_dβ)
end
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ
dba = map(_ -> NoTangent(), ba)
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba...
end

return C′, pullback
end

# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
# C,
# A, pA::Index2Tuple, conjA::Bool,
# B, pB::Index2Tuple, conjB::Bool,
# pAB::Index2Tuple,
# α::Number, β::Number,
# backend, allocator)
# val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β,
# (backend, allocator))
# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
# end
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
# C,
# A, pA::Index2Tuple, conjA::Bool,
# B, pB::Index2Tuple, conjB::Bool,
# pAB::Index2Tuple,
# α::Number, β::Number,
# backend)
# val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, (backend,))
# return val, ΔC -> (pb(ΔC)..., NoTangent())
# end
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
# C,
# A, pA::Index2Tuple, conjA::Bool,
# B, pB::Index2Tuple, conjB::Bool,
# pAB::Index2Tuple,
# α::Number, β::Number)
# return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ())
# end
function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
C,
A, pA::Index2Tuple, conjA::Bool,
B, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple,
α::Number, β::Number,
backend, allocator)
val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β,
(backend, allocator))
return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
end
function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
C,
A, pA::Index2Tuple, conjA::Bool,
B, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple,
α::Number, β::Number,
backend)
val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, (backend,))
return val, ΔC -> (pb(ΔC)..., NoTangent())
end
function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
C,
A, pA::Index2Tuple, conjA::Bool,
B, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple,
α::Number, β::Number)
return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ())
ba...)
return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
end
function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
C′ = tensorcontract!(copy(C), A, pA, conjA, B, pB, conjB, pAB, α, β, ba...)
Expand Down Expand Up @@ -187,32 +204,39 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
((), ()), One(), ba...))
return projectβ(_dβ)
end
dba = map(_ -> NoTangent(), ba)
return NoTangent(), dC,
dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(),
NoTangent(), dα, dβ
NoTangent(), dα, dβ, dba...
end

return C′, pullback
end

# function ChainRulesCore.rrule(::typeof(tensortrace!), C,
# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
# α::Number, β::Number,
# backend, allocator)
# val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend, allocator))
# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
# end
# function ChainRulesCore.rrule(::typeof(tensortrace!), C,
# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
# α::Number, β::Number,
# backend)
# val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend,))
# return val, ΔC -> (pb(ΔC)..., NoTangent())
# end
# function ChainRulesCore.rrule(::typeof(tensortrace!), C,
# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
# α::Number, β::Number)
# return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ())
# end
function ChainRulesCore.rrule(::typeof(tensortrace!), C,
A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
α::Number, β::Number,
backend, allocator)
val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend, allocator))
return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
end
function ChainRulesCore.rrule(::typeof(tensortrace!), C,
A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
α::Number, β::Number,
backend)
val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend,))
return val, ΔC -> (pb(ΔC)..., NoTangent())
end
function ChainRulesCore.rrule(::typeof(tensortrace!), C,
A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
α::Number, β::Number)
return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ())
ba...)
return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
end
function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
C′ = tensortrace!(copy(C), A, p, q, conjA, α, β, ba...)
Expand Down Expand Up @@ -253,7 +277,8 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
((), ()), One(), ba...))
return projectβ(_dβ)
end
return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ
dba = map(_ -> NoTangent(), ba)
return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba...
end

return C′, pullback
Expand Down
18 changes: 10 additions & 8 deletions ext/TensorOperationscuTENSORExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,16 @@ using CUDA.Adapt: adapt
using Strided
using TupleTools: TupleTools as TT

const StridedViewsCUDAExt = @static if isdefined(Base, :get_extension)
Base.get_extension(Strided.StridedViews, :StridedViewsCUDAExt)
else
Strided.StridedViews.StridedViewsCUDAExt
end
isnothing(StridedViewsCUDAExt) && error("StridedViewsCUDAExt not found")
# Disallowed paradigm from Julia 1.11.1 onwards:
# const StridedViewsCUDAExt = @static if isdefined(Base, :get_extension)
# Base.get_extension(Strided.StridedViews, :StridedViewsCUDAExt)
# else
# Strided.StridedViews.StridedViewsCUDAExt
# end
# isnothing(StridedViewsCUDAExt) && error("StridedViewsCUDAExt not found")

# Literal copy of the StridedViewsCUDAExt module
const CuStridedView{T,N,A<:CuArray{T}} = StridedView{T,N,A}

#-------------------------------------------------------------------------------------------
# @cutensor macro
Expand All @@ -53,8 +57,6 @@ end
#-------------------------------------------------------------------------------------------
# Backend selection and passing
#-------------------------------------------------------------------------------------------
const CuStridedView = StridedViewsCUDAExt.CuStridedView

# A Base wrapper over `CuArray` will first pass via the `select_backend` methods for
# `AbstractArray` and be converted into a `StridedView` if it satisfies `isstrided`. Hence,
# we only need to capture `CuStridedView` here.
Expand Down

0 comments on commit 42e9c9c

Please sign in to comment.