update strided implementation #191
I like this change a lot, as this is definitely functionality that can be re-used in many other places.
The only content-related comment I can come up with is whether the `stridedtensorcontract!` implementation, which basically checks the memory costs and then dispatches to `blas_contract!`, really is specific to Strided, but I guess the memory model depends on the strided way of doing permutations.
Otherwise, maybe we could consider moving the `blas_contract!` functions into a different file? In principle, the implementation that's now here is not even BLAS-specific; it really is just transpose-transpose-gemm-transpose, which should also work for `AbstractArray`, and maybe no longer really belongs in the strided file.
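To make the transpose-transpose-gemm-transpose remark concrete, here is a minimal sketch of that pattern for plain `AbstractArray`s, using only Base and the LinearAlgebra standard library. The function name `ttgt_contract` and its argument layout are hypothetical and chosen for illustration; this is not the package's `blas_contract!`, and it ignores α, β, conjugation and allocators.

```julia
using LinearAlgebra

# Hypothetical helper (not part of TensorOperations.jl).
# oindA/cindA: open and contracted dimensions of A
# cindB/oindB: contracted and open dimensions of B (cindB ordered to match cindA)
# pAB: permutation applied to the (open A..., open B...) dimensions of the result
function ttgt_contract(A::AbstractArray, oindA, cindA,
                       B::AbstractArray, cindB, oindB, pAB)
    # transpose 1: open dimensions of A first, contracted dimensions last
    Ap = permutedims(A, (oindA..., cindA...))
    # transpose 2: contracted dimensions of B first, open dimensions last
    Bp = permutedims(B, (cindB..., oindB...))
    szoA = size(Ap)[1:length(oindA)]
    szoB = size(Bp)[(length(cindB) + 1):end]
    # gemm: multiply the matricized tensors
    Cmat = reshape(Ap, prod(szoA), :) * reshape(Bp, :, prod(szoB))
    # transpose 3: restore the tensor shape and permute into the requested order
    return permutedims(reshape(Cmat, (szoA..., szoB...)), pAB)
end

# Usage: C[l, i] = sum over j, k of A[i, j, k] * B[k, j, l]
A = rand(2, 3, 4)
B = rand(4, 3, 3)
C = ttgt_contract(A, (1,), (2, 3), B, (2, 1), (3,), (2, 1))
```

Nothing in this sketch depends on strides or on BLAS being available, which is the point made in the comment: only the permutation step would need a backend-specific implementation.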
src/implementation/strided.jl
Outdated
@@ -228,7 +231,7 @@ function blas_contract!(C, A, pA, B, pB, pAB, α, β, allocator)
 C_ = SV(tensoralloc_add(TC, C, ipAB, false, Val(true), allocator))
 _unsafe_blas_contract!(C_, A_, pA, B_, pB, trivialpermutation(ipAB),
 one(TC), zero(TC))
-stridedtensoradd!(C, C_, pAB, α, β, StridedNative(), allocator)
+stridedtensoradd!(C, C_, pAB, α, β, backend, allocator)
Should this be `stridedtensoradd!` or just `tensoradd!`?
Right, if this is to work generally, it should be `tensoradd!`.
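The difference between hard-coding `StridedNative()` and forwarding `backend` in the diff above can be shown with a toy model. All names below (`ToyBackend`, `NativeBackend`, `OtherBackend`, `toy_add!`, `contract_hardcoded!`, `contract_forwarding!`) are hypothetical stand-ins, not the package's API; the sketch only records which code path would run.

```julia
# Hypothetical stand-ins for backend types; OtherBackend plays the role of
# any backend that brings its own permutation/addition kernel.
abstract type ToyBackend end
struct NativeBackend <: ToyBackend end
struct OtherBackend <: ToyBackend end

# Backend-specific final "add" step: each backend supplies its own method.
toy_add!(log, ::NativeBackend) = push!(log, :native_add)
toy_add!(log, ::OtherBackend) = push!(log, :other_add)

# Hard-coded variant: always finishes with the native backend,
# whatever backend the caller asked for.
function contract_hardcoded!(log, backend::ToyBackend)
    push!(log, :gemm)
    return toy_add!(log, NativeBackend())
end

# Forwarding variant: passes the caller's backend on to the final add step.
function contract_forwarding!(log, backend::ToyBackend)
    push!(log, :gemm)
    return toy_add!(log, backend)
end

hardcoded = contract_hardcoded!(Symbol[], OtherBackend())
forwarded = contract_forwarding!(Symbol[], OtherBackend())
```

With the forwarding variant, the shared routine no longer needs to know which backend it serves, which is what calling a generic entry point instead of a strided-specific one would achieve.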
Yes, I think making a separate
Ok, the errors seem CUDA related on x86 platforms, only in the latest version (1.11). I guess we can ignore this and this is ready to be merged?
I think so, yes
This PR contains two changes:
- It adds the `backend` argument throughout the lower methods of the StridedBLAS implementation. This enables the lower methods (like `blas_contract!`) to be reused by other backends that only provide a specialised tensor permutation implementation, in particular the upcoming HPTT.
- It adds the `allocator` argument to the specialised `stridedtensorcontract!` method where all arguments are simply matrices. Without this argument, this method would never be called, because the argument is inserted higher up in the call chain, so the general method for arbitrary-rank tensors ends up being used instead.
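The second change comes down to how Julia's multiple dispatch treats argument counts. The following sketch reproduces the situation with hypothetical names (`ToyBackend`, `ToyAllocator`, `kind_old`, `kind_new`); it is not the package's code, only an illustration of why a specialised method that lacks a trailing argument is unreachable when callers always supply it.

```julia
# Hypothetical placeholder types for the two trailing arguments.
struct ToyBackend end
struct ToyAllocator end

# Before: the matrix specialisation lacks the allocator argument, so a
# four-argument call can only match the general method.
kind_old(A::AbstractArray, B::AbstractArray, backend, allocator) = :general
kind_old(A::AbstractMatrix, B::AbstractMatrix, backend) = :matrix

# After: both methods take the allocator, so dispatch picks the more
# specific matrix method when both inputs are matrices.
kind_new(A::AbstractArray, B::AbstractArray, backend, allocator) = :general
kind_new(A::AbstractMatrix, B::AbstractMatrix, backend, allocator) = :matrix

M = rand(2, 2)
T = rand(2, 2, 2)

# The higher-level caller always inserts the allocator:
old_result = kind_old(M, M, ToyBackend(), ToyAllocator())      # general method
new_result = kind_new(M, M, ToyBackend(), ToyAllocator())      # matrix method
tensor_result = kind_new(T, T, ToyBackend(), ToyAllocator())   # general method
```

Because no method error is raised in the "before" case, this kind of mismatch is silent: the call still succeeds, only through the slower general path.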