Skip to content

Commit

Permalink
force strided specialisation
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Oct 18, 2024
1 parent 42e9c9c commit 2a56132
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/implementation/strided.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ end
#-------------------------------------------------------------------------------------------
# StridedView implementation
#-------------------------------------------------------------------------------------------
struct Adder end
(::Adder)(x, y) = VectorInterface.add(x, y)
struct Scaler{T}
α::T
end
(s::Scaler)(x) = scale(x, s.α)
(s::Scaler)(x, y) = scale(x * y, s.α)

function stridedtensoradd!(C::StridedView,
A::StridedView, pA::Index2Tuple,
α::Number, β::Number,
Expand All @@ -102,9 +110,7 @@ function stridedtensoradd!(C::StridedView,
end

A′ = permutedims(A, linearize(pA))
op1 = Base.Fix2(scale, α)
op2 = Base.Fix2(scale, β)
Strided._mapreducedim!(op1, +, op2, size(C), (C, A′))
Strided._mapreducedim!(Scaler(α), Adder(), Scaler(β), size(C), (C, A′))
return C
end

Expand All @@ -125,9 +131,7 @@ function stridedtensortrace!(C::StridedView,
newsize = (size(C)..., tracesize...)

A′ = SV(A.parent, newsize, newstrides, A.offset, A.op)
op1 = Base.Fix2(scale, α)
op2 = Base.Fix2(scale, β)
Strided._mapreducedim!(op1, +, op2, newsize, (C, A′))
Strided._mapreducedim!(Scaler(α), Adder(), Scaler(β), newsize, (C, A′))
return C
end

Expand Down Expand Up @@ -170,8 +174,6 @@ function stridedtensorcontract!(C::StridedView,
(osizeA..., osizeB..., one.(csizeA)...))
tsize = (osizeA..., osizeB..., csizeA...)

op1 = Base.Fix2(scale, α) *
op2 = Base.Fix2(scale, β)
Strided._mapreducedim!(op1, +, op2, tsize, (CS, AS, BS))
Strided._mapreducedim!(Scaler(α), Adder(), Scaler(β), tsize, (CS, AS, BS))
return C
end

2 comments on commit 2a56132

@Jutho
Copy link
Owner Author

@Jutho Jutho commented on 2a56132 Oct 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

  • remove small type instability in Strided implementations
  • fix broken extension behaviour after Julia 1.11.1 update

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/117609

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v5.1.0 -m "<description of version>" 2a56132fde2151fa8bb5fb54357a14f19cee135a
git push origin v5.1.0

Please sign in to comment.