diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
index 627e0126d1fe89..f2681400b786a9 100644
--- a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
+++ b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
@@ -139,15 +139,14 @@ inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x,
   const IntT* cols_data;
   int64_t batch_nnz;
   int batch_size = 1;
+  for (int i = 0; i < x_ndims - 2; i++) {
+    batch_size *= xdim_vec[i];
+  }
+  PADDLE_ENFORCE_EQ(x.non_zero_crows().numel(),
+                    batch_size * (M + 1),
+                    phi::errors::PreconditionNotMet(
+                        "the length of SparseCsrTensor crows is not right."));
   if (!is_out) {
-    for (int i = 0; i < x_ndims - 2; i++) {
-      batch_size *= xdim_vec[i];
-    }
-    PADDLE_ENFORCE_EQ(x.non_zero_crows().numel(),
-                      batch_size * (M + 1),
-                      phi::errors::PreconditionNotMet(
-                          "the length of SparseCsrTensor crows is not right."));
-
     crows_data = x.non_zero_crows().data<IntT>();
     cols_data = x.non_zero_cols().data<IntT>();
     values_data = x.non_zero_elements().data<T>();
@@ -262,6 +261,7 @@ class CuSparseSpMatDescriptor {
         }));
     VLOG(6) << "Create coo cusparseSpMatDescr_t " << &descriptor_;
   }
+
   explicit CuSparseSpMatDescriptor(const phi::SparseCsrTensor& x,
                                    const phi::GPUContext& dev_ctx,
                                    bool is_out)
@@ -272,6 +272,7 @@ class CuSparseSpMatDescriptor {
         }));
     VLOG(6) << "Create csr cusparseSpMatDescr_t for SPGEMM" << &descriptor_;
   }
+
   ~CuSparseSpMatDescriptor() {
     dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
       phi::dynload::cusparseDestroySpMat(descriptor_);
diff --git a/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu
index e64c400375b01a..8d766b22a44c2e 100644
--- a/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu
@@ -154,9 +154,9 @@ void MatmulCsrCsrGradKernel(const Context& dev_ctx,
     auto dims_numel = y.dims().size();
     SparseCsrTensor tmp_y;
     if (dims_numel == 2) {
-      TransposeCsrIntKernel<T, Context>(dev_ctx, y, {1, 0}, &tmp_y);
+      TransposeCsrKernel<T, Context>(dev_ctx, y, {1, 0}, &tmp_y);
     } else {
-      TransposeCsrIntKernel<T, Context>(dev_ctx, y, {0, 2, 1}, &tmp_y);
+      TransposeCsrKernel<T, Context>(dev_ctx, y, {0, 2, 1}, &tmp_y);
     }
 
     sparse_blas.SPMM(
@@ -168,9 +168,9 @@
     auto dims_numel = x.dims().size();
     SparseCsrTensor tmp_x;
     if (dims_numel == 2) {
-      TransposeCsrIntKernel<T, Context>(dev_ctx, x, {1, 0}, &tmp_x);
+      TransposeCsrKernel<T, Context>(dev_ctx, x, {1, 0}, &tmp_x);
     } else {
-      TransposeCsrIntKernel<T, Context>(dev_ctx, x, {0, 2, 1}, &tmp_x);
+      TransposeCsrKernel<T, Context>(dev_ctx, x, {0, 2, 1}, &tmp_x);
     }
 
     sparse_blas.SPMM(
diff --git a/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
index cb2605c94407e5..f5a34cd0656cde 100644
--- a/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
@@ -202,11 +202,11 @@ void MaskedMatmulCsrKernel(const Context& dev_ctx,
 #endif
 }
 
-template <typename T, typename Context, typename TensorType>
-void MatmulKernelImpl(const Context& dev_ctx,
-                      const TensorType& x,
-                      const TensorType& y,
-                      TensorType* out) {
+template <typename T, typename Context>
+void MatmulCsrCsrKernel(const Context& dev_ctx,
+                        const SparseCsrTensor& x,
+                        const SparseCsrTensor& y,
+                        SparseCsrTensor* out) {
 #if CUDA_VERSION >= 11000
   std::vector<int64_t> xdim_vec = common::vectorize(x.dims());
   std::vector<int64_t> ydim_vec = common::vectorize(y.dims());
@@ -257,14 +257,6 @@ void MatmulKernelImpl(const Context& dev_ctx,
 #endif
 }
 
-template <typename T, typename Context>
-void MatmulCsrCsrKernel(const Context& dev_ctx,
-                        const SparseCsrTensor& x,
-                        const SparseCsrTensor& y,
-                        SparseCsrTensor* out) {
-  MatmulKernelImpl<T>(dev_ctx, x, y, out);
-}
-
 template <typename T, typename Context>
 void MatmulCooCooKernel(const Context& dev_ctx,
                         const SparseCooTensor& x,
@@ -274,7 +266,7 @@ void MatmulCooCooKernel(const Context& dev_ctx,
   SparseCsrTensor y_csr = CooToCsr<T, Context>(dev_ctx, y);
   SparseCsrTensor out_csr;
   out_csr.set_dims(out->dims());
-  MatmulKernelImpl<T>(dev_ctx, x_csr, y_csr, &out_csr);
+  MatmulCsrCsrKernel<T, Context>(dev_ctx, x_csr, y_csr, &out_csr);
   CsrToCooKernel<T, Context>(dev_ctx, out_csr, out);
 }
 
diff --git a/paddle/phi/kernels/sparse/gpu/transpose_kernel.cu b/paddle/phi/kernels/sparse/gpu/transpose_kernel.cu
index 1efebd3131f9a6..e1b13227aa54f7 100644
--- a/paddle/phi/kernels/sparse/gpu/transpose_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/transpose_kernel.cu
@@ -17,6 +17,7 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/visit_type.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
@@ -317,15 +318,10 @@ void TransposeCsrKernel(const Context &dev_ctx,
                         const SparseCsrTensor &x,
                         const std::vector<int> &perm,
                         SparseCsrTensor *out) {
-  TransposeCsrImpl<T, int64_t>(dev_ctx, x, perm, out);
-}
-
-template <typename T, typename Context>
-void TransposeCsrIntKernel(const Context &dev_ctx,
-                           const SparseCsrTensor &x,
-                           const std::vector<int> &perm,
-                           SparseCsrTensor *out) {
-  TransposeCsrImpl<T, int>(dev_ctx, x, perm, out);
+  PD_VISIT_BASE_INTEGRAL_TYPES(
+      x.non_zero_crows().dtype(), "TransposeCsrKernel", ([&] {
+        TransposeCsrImpl<T, data_t>(dev_ctx, x, perm, out);
+      }));
 }
 }  // namespace sparse
 }  // namespace phi
@@ -357,17 +353,3 @@ PD_REGISTER_KERNEL(transpose_csr,
                    GPU,
                    ALL_LAYOUT,
                    phi::sparse::TransposeCsrKernel,
                    phi::dtype::float16,
                    float,
                    double,
                    int,
                    int64_t,
                    bool) {}
-
-PD_REGISTER_KERNEL(transpose_csr_int,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::sparse::TransposeCsrIntKernel,
-                   phi::dtype::float16,
-                   float,
-                   double,
-                   int8_t,
-                   uint8_t,
-                   int16_t,
-                   int,
-                   int64_t,
-                   bool) {}
diff --git a/paddle/phi/kernels/sparse/unary_kernel.h b/paddle/phi/kernels/sparse/unary_kernel.h
index 266901637d26f9..dff8742f5afc79 100644
--- a/paddle/phi/kernels/sparse/unary_kernel.h
+++ b/paddle/phi/kernels/sparse/unary_kernel.h
@@ -126,12 +126,6 @@ void TransposeCsrKernel(const Context& dev_ctx,
                         const std::vector<int>& perm,
                         SparseCsrTensor* out);
 
-template <typename T, typename Context>
-void TransposeCsrIntKernel(const Context& dev_ctx,
-                           const SparseCsrTensor& x,
-                           const std::vector<int>& perm,
-                           SparseCsrTensor* out);
-
 template <typename T, typename Context>
 SparseCooTensor TransposeCoo(const Context& dev_ctx,
                              const SparseCooTensor& x,
diff --git a/test/legacy_test/test_sparse_matmul_op.py b/test/legacy_test/test_sparse_matmul_op.py
index ad8e2715635a10..1119d88d08ed5f 100644
--- a/test/legacy_test/test_sparse_matmul_op.py
+++ b/test/legacy_test/test_sparse_matmul_op.py
@@ -170,9 +170,9 @@ def test_masked_matmul_3d(self):
         )
 
 
-class TestMatmulCSR(unittest.TestCase):
-    # x: csr sparse, y: csr sparse, out: csr sparse
-    def check_result(self, x_shape, y_shape):
+class TestMatmulSparseSparse(unittest.TestCase):
+    # x: sparse, y: sparse, out: sparse
+    def check_result(self, x_shape, y_shape, sparse):
         mask = paddle.randint(0, 2, x_shape)
         origin_x = paddle.rand(x_shape) * mask
         origin_y = paddle.rand(y_shape)
@@ -182,22 +182,37 @@ def check_result(self, x_shape, y_shape):
         dense_y = origin_y.detach()
         dense_y.stop_gradient = False
         dense_out = paddle.matmul(dense_x, dense_y)
+        if sparse == 'csr':
+            sp_x = origin_x.detach().to_sparse_csr()
+            # only support 32-bit index.
+            sp_x_crows = paddle.cast(sp_x.crows(), "int32")
+            sp_x_cols = paddle.cast(sp_x.cols(), "int32")
+            sp_x = paddle.sparse.sparse_csr_tensor(
+                sp_x_crows, sp_x_cols, sp_x.values(), sp_x.shape
+            )
 
-        sp_x = origin_x.detach().to_sparse_csr()
-        # only support 32-bit index.
-        sp_x_crows = paddle.cast(sp_x.crows(), "int32")
-        sp_x_cols = paddle.cast(sp_x.cols(), "int32")
-        sp_x = paddle.sparse.sparse_csr_tensor(
-            sp_x_crows, sp_x_cols, sp_x.values(), sp_x.shape
-        )
+            sp_y = origin_y.detach().to_sparse_csr()
+            # only support 32-bit index.
+            sp_y_crows = paddle.cast(sp_y.crows(), "int32")
+            sp_y_cols = paddle.cast(sp_y.cols(), "int32")
+            sp_y = paddle.sparse.sparse_csr_tensor(
+                sp_y_crows, sp_y_cols, sp_y.values(), sp_y.shape
+            )
+        else:
+            sp_x = origin_x.detach().to_sparse_coo(len(x_shape))
 
-        sp_y = origin_y.detach().to_sparse_csr()
-        # only support 32-bit index.
-        sp_y_crows = paddle.cast(sp_y.crows(), "int32")
-        sp_y_cols = paddle.cast(sp_y.cols(), "int32")
-        sp_y = paddle.sparse.sparse_csr_tensor(
-            sp_y_crows, sp_y_cols, sp_y.values(), sp_y.shape
-        )
+            # only support 32-bit index.
+            sp_x_indices = paddle.cast(sp_x.indices(), "int32")
+            sp_x = paddle.sparse.sparse_coo_tensor(
+                sp_x_indices, sp_x.values(), sp_x.shape
+            )
+
+            sp_y = origin_y.detach().to_sparse_coo(len(y_shape))
+            # only support 32-bit index.
+            sp_y_indices = paddle.cast(sp_y.indices(), "int32")
+            sp_y = paddle.sparse.sparse_coo_tensor(
+                sp_y_indices, sp_y.values(), sp_y.shape
+            )
 
         sp_x.stop_gradient = False
         sp_y.stop_gradient = False
@@ -207,87 +222,33 @@ def check_result(self, x_shape, y_shape):
         np.testing.assert_allclose(
             sp_out.to_dense().numpy(), dense_out.numpy(), rtol=1e-05
         )
-        if get_cuda_version() >= 11000:
-            dense_out.backward()
-            sp_out.backward()
-            np.testing.assert_allclose(
-                sp_x.grad.to_dense().numpy(),
-                dense_x.grad.numpy(),
-                rtol=1e-05,
-            )
-            np.testing.assert_allclose(
-                sp_y.grad.to_dense().numpy(), dense_y.grad.numpy(), rtol=1e-05
-            )
-
-    @unittest.skipIf(
-        not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000,
-        "only support cuda>=11.0",
-    )
-    def test_matmul_2d(self):
-        self.check_result([16, 12], [12, 10])
-
-    @unittest.skipIf(
-        not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000,
-        "only support cuda>=11.0",
-    )
-    def test_matmul_3d(self):
-        self.check_result([2, 16, 12], [2, 12, 10])
-
-
-class TestMatmulCOO(unittest.TestCase):
-    # x: coo sparse, y: coo sparse, out: coo sparse
-    def check_result(self, x_shape, y_shape):
-        mask = paddle.randint(0, 2, x_shape)
-        origin_x = paddle.rand(x_shape) * mask
-        origin_y = paddle.rand(y_shape)
-
-        dense_x = origin_x.detach()
-        dense_x.stop_gradient = False
-        dense_y = origin_y.detach()
-        dense_y.stop_gradient = False
-        dense_out = paddle.matmul(dense_x, dense_y)
-
-        sp_x = origin_x.detach().to_sparse_coo(len(x_shape))
-
-        # only support 32-bit index.
-        sp_x_indices = paddle.cast(sp_x.indices(), "int32")
-        sp_x = paddle.sparse.sparse_coo_tensor(
-            sp_x_indices, sp_x.values(), sp_x.shape
-        )
-
-        sp_y = origin_y.detach().to_sparse_coo(len(y_shape))
-        # only support 32-bit index.
-        sp_y_indices = paddle.cast(sp_y.indices(), "int32")
-        sp_y = paddle.sparse.sparse_coo_tensor(
-            sp_y_indices, sp_y.values(), sp_y.shape
+        dense_out.backward()
+        sp_out.backward()
+        np.testing.assert_allclose(
+            sp_x.grad.to_dense().numpy(),
+            dense_x.grad.numpy(),
+            rtol=1e-05,
         )
-
-        sp_x.stop_gradient = False
-        sp_y.stop_gradient = False
-
-        sp_out = paddle.sparse.matmul(sp_x, sp_y)
         np.testing.assert_allclose(
-            sp_out.to_dense().numpy(), dense_out.numpy(), rtol=1e-05
+            sp_y.grad.to_dense().numpy(), dense_y.grad.numpy(), rtol=1e-05
         )
-        if get_cuda_version() >= 11000:
-            dense_out.backward()
-            sp_out.backward()
-            np.testing.assert_allclose(
-                sp_x.grad.to_dense().numpy(),
-                dense_x.grad.numpy(),
-                rtol=1e-05,
-            )
-            np.testing.assert_allclose(
-                sp_y.grad.to_dense().numpy(), dense_y.grad.numpy(), rtol=1e-05
-            )
 
     @unittest.skipIf(
         not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000,
         "only support cuda>=11.0",
     )
     def test_matmul_2d(self):
-        self.check_result([16, 12], [12, 10])
+        self.check_result([16, 12], [12, 10], 'csr')
+        self.check_result([16, 12], [12, 10], 'coo')
+
+    @unittest.skipIf(
+        not paddle.is_compiled_with_cuda() or get_cuda_version() < 11080,
+        "only support cuda>=11.8",
+    )
+    def test_matmul_3d(self):
+        self.check_result([2, 16, 12], [2, 12, 10], 'csr')
+        self.check_result([2, 16, 12], [2, 12, 10], 'coo')
 
 
 if __name__ == "__main__":