Commit

fix
yangguohao committed Dec 19, 2023
1 parent b09446e commit 8ec9352
Showing 6 changed files with 73 additions and 143 deletions.
17 changes: 9 additions & 8 deletions paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
@@ -139,15 +139,14 @@ inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x,
const IntT* cols_data;
int64_t batch_nnz;
int batch_size = 1;
for (int i = 0; i < x_ndims - 2; i++) {
batch_size *= xdim_vec[i];
}
PADDLE_ENFORCE_EQ(x.non_zero_crows().numel(),
batch_size * (M + 1),
phi::errors::PreconditionNotMet(
"the length of SparseCsrTensor crows is not right."));
if (!is_out) {
for (int i = 0; i < x_ndims - 2; i++) {
batch_size *= xdim_vec[i];
}
PADDLE_ENFORCE_EQ(x.non_zero_crows().numel(),
batch_size * (M + 1),
phi::errors::PreconditionNotMet(
"the length of SparseCsrTensor crows is not right."));

crows_data = x.non_zero_crows().data<IntT>();
cols_data = x.non_zero_cols().data<IntT>();
values_data = x.non_zero_elements().data<T>();
@@ -262,6 +261,7 @@ class CuSparseSpMatDescriptor {
}));
VLOG(6) << "Create coo cusparseSpMatDescr_t " << &descriptor_;
}

explicit CuSparseSpMatDescriptor(const phi::SparseCsrTensor& x,
const phi::GPUContext& dev_ctx,
bool is_out)
@@ -272,6 +272,7 @@
}));
VLOG(6) << "Create csr cusparseSpMatDescr_t for SPGEMM" << &descriptor_;
}

~CuSparseSpMatDescriptor() {
dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseDestroySpMat(descriptor_);
8 changes: 4 additions & 4 deletions paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu
@@ -154,9 +154,9 @@ void MatmulCsrCsrGradKernel(const Context& dev_ctx,
auto dims_numel = y.dims().size();
SparseCsrTensor tmp_y;
if (dims_numel == 2) {
TransposeCsrIntKernel<T, Context>(dev_ctx, y, {1, 0}, &tmp_y);
TransposeCsrKernel<T, Context>(dev_ctx, y, {1, 0}, &tmp_y);
} else {
TransposeCsrIntKernel<T, Context>(dev_ctx, y, {0, 2, 1}, &tmp_y);
TransposeCsrKernel<T, Context>(dev_ctx, y, {0, 2, 1}, &tmp_y);
}

sparse_blas.SPMM(
@@ -168,9 +168,9 @@
auto dims_numel = x.dims().size();
SparseCsrTensor tmp_x;
if (dims_numel == 2) {
TransposeCsrIntKernel<T, Context>(dev_ctx, x, {1, 0}, &tmp_x);
TransposeCsrKernel<T, Context>(dev_ctx, x, {1, 0}, &tmp_x);
} else {
TransposeCsrIntKernel<T, Context>(dev_ctx, x, {0, 2, 1}, &tmp_x);
TransposeCsrKernel<T, Context>(dev_ctx, x, {0, 2, 1}, &tmp_x);
}

sparse_blas.SPMM(
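
Note: the hunk above only swaps the transpose helper used inside the CSR x CSR backward; the Python-visible behaviour is unchanged. A minimal sketch of the path it exercises (gradients of paddle.sparse.matmul on two CSR operands), assuming a CUDA >= 11.0 build; the to_csr_int32 helper is illustrative only and mirrors the int32 cast used in the updated test below:

```python
import numpy as np
import paddle

# Dense reference.
x_shape, y_shape = [16, 12], [12, 10]
mask = paddle.randint(0, 2, x_shape)
origin_x = paddle.rand(x_shape) * mask
origin_y = paddle.rand(y_shape)
dense_x, dense_y = origin_x.detach(), origin_y.detach()
dense_x.stop_gradient = False
dense_y.stop_gradient = False
dense_out = paddle.matmul(dense_x, dense_y)

def to_csr_int32(t):
    # The cuSPARSE path only supports 32-bit indices, so cast crows/cols.
    csr = t.detach().to_sparse_csr()
    return paddle.sparse.sparse_csr_tensor(
        paddle.cast(csr.crows(), 'int32'),
        paddle.cast(csr.cols(), 'int32'),
        csr.values(),
        csr.shape,
    )

sp_x, sp_y = to_csr_int32(origin_x), to_csr_int32(origin_y)
sp_x.stop_gradient = False
sp_y.stop_gradient = False
sp_out = paddle.sparse.matmul(sp_x, sp_y)

# Backward runs MatmulCsrCsrGradKernel, which transposes x and y with
# TransposeCsrKernel before the two SPMM calls.
dense_out.backward()
sp_out.backward()
np.testing.assert_allclose(sp_x.grad.to_dense().numpy(),
                           dense_x.grad.numpy(), rtol=1e-5)
```
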
20 changes: 6 additions & 14 deletions paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
@@ -202,11 +202,11 @@ void MaskedMatmulCsrKernel(const Context& dev_ctx,
#endif
}

template <typename T, typename Context, typename TensorType>
void MatmulKernelImpl(const Context& dev_ctx,
const TensorType& x,
const TensorType& y,
TensorType* out) {
template <typename T, typename Context>
void MatmulCsrCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& y,
SparseCsrTensor* out) {
#if CUDA_VERSION >= 11000
std::vector<int64_t> xdim_vec = common::vectorize(x.dims());
std::vector<int64_t> ydim_vec = common::vectorize(y.dims());
@@ -257,14 +257,6 @@ void MatmulKernelImpl(const Context& dev_ctx,
#endif
}

template <typename T, typename Context>
void MatmulCsrCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& y,
SparseCsrTensor* out) {
MatmulKernelImpl<T>(dev_ctx, x, y, out);
}

template <typename T, typename Context>
void MatmulCooCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
@@ -274,7 +266,7 @@ void MatmulCooCooKernel(const Context& dev_ctx,
SparseCsrTensor y_csr = CooToCsr<T, Context>(dev_ctx, y);
SparseCsrTensor out_csr;
out_csr.set_dims(out->dims());
MatmulKernelImpl<T>(dev_ctx, x_csr, y_csr, &out_csr);
MatmulCsrCsrKernel<T, Context>(dev_ctx, x_csr, y_csr, &out_csr);
CsrToCooKernel<T>(dev_ctx, out_csr, out);
}

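Note: with the MatmulKernelImpl wrapper removed, MatmulCsrCsrKernel is the single CSR x CSR implementation, and the COO x COO kernel reuses it by converting its operands (CooToCsr, then CsrToCooKernel on the result). A minimal sketch of the corresponding Python-level call, assuming a CUDA >= 11.0 build:

```python
import numpy as np
import paddle

x_shape, y_shape = [16, 12], [12, 10]
mask = paddle.randint(0, 2, x_shape)
origin_x = paddle.rand(x_shape) * mask
origin_y = paddle.rand(y_shape)
dense_out = paddle.matmul(origin_x, origin_y)

# COO x COO: internally both operands are converted to CSR, the CSR x CSR
# SpGEMM runs, and the result is converted back to COO.
sp_x = origin_x.detach().to_sparse_coo(len(x_shape))
sp_y = origin_y.detach().to_sparse_coo(len(y_shape))
# Only 32-bit indices are supported on this path.
sp_x = paddle.sparse.sparse_coo_tensor(
    paddle.cast(sp_x.indices(), 'int32'), sp_x.values(), sp_x.shape)
sp_y = paddle.sparse.sparse_coo_tensor(
    paddle.cast(sp_y.indices(), 'int32'), sp_y.values(), sp_y.shape)

sp_out = paddle.sparse.matmul(sp_x, sp_y)  # returns a COO tensor
np.testing.assert_allclose(sp_out.to_dense().numpy(),
                           dense_out.numpy(), rtol=1e-5)
```
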
28 changes: 5 additions & 23 deletions paddle/phi/kernels/sparse/gpu/transpose_kernel.cu
@@ -17,6 +17,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
@@ -317,15 +318,10 @@ void TransposeCsrKernel(const Context &dev_ctx,
const SparseCsrTensor &x,
const std::vector<int> &perm,
SparseCsrTensor *out) {
TransposeCsrImpl<T, int64_t, Context>(dev_ctx, x, perm, out);
}

template <typename T, typename Context>
void TransposeCsrIntKernel(const Context &dev_ctx,
const SparseCsrTensor &x,
const std::vector<int> &perm,
SparseCsrTensor *out) {
TransposeCsrImpl<T, int, Context>(dev_ctx, x, perm, out);
PD_VISIT_BASE_INTEGRAL_TYPES(
x.non_zero_crows().dtype(), "TransposeCsrKernel", ([&] {
TransposeCsrImpl<T, data_t, Context>(dev_ctx, x, perm, out);
}));
}
} // namespace sparse
} // namespace phi
@@ -357,17 +353,3 @@ PD_REGISTER_KERNEL(transpose_csr,
int,
int64_t,
bool) {}

PD_REGISTER_KERNEL(transpose_csr_int,
GPU,
ALL_LAYOUT,
phi::sparse::TransposeCsrIntKernel,
phi::dtype::float16,
float,
double,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
bool) {}
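
Note: the separate transpose_csr_int registration is gone; TransposeCsrKernel now selects the index type at runtime from the crows dtype via PD_VISIT_BASE_INTEGRAL_TYPES. A minimal sketch of both index widths reaching the same kernel, assuming paddle.sparse.transpose accepts CSR input in this build:

```python
import paddle

dense = paddle.rand([4, 6])
csr64 = dense.to_sparse_csr()  # crows/cols are int64 by default

# Same matrix with 32-bit indices, mirroring the cast used in the tests.
csr32 = paddle.sparse.sparse_csr_tensor(
    paddle.cast(csr64.crows(), 'int32'),
    paddle.cast(csr64.cols(), 'int32'),
    csr64.values(),
    csr64.shape,
)

# Both calls hit the single registered transpose_csr kernel; the index-type
# dispatch happens inside it based on the crows dtype.
t64 = paddle.sparse.transpose(csr64, [1, 0])
t32 = paddle.sparse.transpose(csr32, [1, 0])
```
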
6 changes: 0 additions & 6 deletions paddle/phi/kernels/sparse/unary_kernel.h
@@ -126,12 +126,6 @@ void TransposeCsrKernel(const Context& dev_ctx,
const std::vector<int>& perm,
SparseCsrTensor* out);

template <typename T, typename Context>
void TransposeCsrIntKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const std::vector<int>& perm,
SparseCsrTensor* out);

template <typename T, typename Context>
SparseCooTensor TransposeCoo(const Context& dev_ctx,
const SparseCooTensor& x,
137 changes: 49 additions & 88 deletions test/legacy_test/test_sparse_matmul_op.py
@@ -170,9 +170,9 @@ def test_masked_matmul_3d(self):
)


class TestMatmulCSR(unittest.TestCase):
# x: csr sparse, y: csr sparse, out: csr sparse
def check_result(self, x_shape, y_shape):
class TestMatmulSparseSparse(unittest.TestCase):
# x: sparse, y: sparse, out: sparse
def check_result(self, x_shape, y_shape, sparse):
mask = paddle.randint(0, 2, x_shape)
origin_x = paddle.rand(x_shape) * mask
origin_y = paddle.rand(y_shape)
@@ -182,22 +182,37 @@ def check_result(self, x_shape, y_shape):
dense_y = origin_y.detach()
dense_y.stop_gradient = False
dense_out = paddle.matmul(dense_x, dense_y)
if sparse == 'csr':
sp_x = origin_x.detach().to_sparse_csr()
# only support 32-bit index.
sp_x_crows = paddle.cast(sp_x.crows(), "int32")
sp_x_cols = paddle.cast(sp_x.cols(), "int32")
sp_x = paddle.sparse.sparse_csr_tensor(
sp_x_crows, sp_x_cols, sp_x.values(), sp_x.shape
)

sp_x = origin_x.detach().to_sparse_csr()
# only support 32-bit index.
sp_x_crows = paddle.cast(sp_x.crows(), "int32")
sp_x_cols = paddle.cast(sp_x.cols(), "int32")
sp_x = paddle.sparse.sparse_csr_tensor(
sp_x_crows, sp_x_cols, sp_x.values(), sp_x.shape
)
sp_y = origin_y.detach().to_sparse_csr()
# only support 32-bit index.
sp_y_crows = paddle.cast(sp_y.crows(), "int32")
sp_y_cols = paddle.cast(sp_y.cols(), "int32")
sp_y = paddle.sparse.sparse_csr_tensor(
sp_y_crows, sp_y_cols, sp_y.values(), sp_y.shape
)
else:
sp_x = origin_x.detach().to_sparse_coo(len(x_shape))

sp_y = origin_y.detach().to_sparse_csr()
# only support 32-bit index.
sp_y_crows = paddle.cast(sp_y.crows(), "int32")
sp_y_cols = paddle.cast(sp_y.cols(), "int32")
sp_y = paddle.sparse.sparse_csr_tensor(
sp_y_crows, sp_y_cols, sp_y.values(), sp_y.shape
)
# only support 32-bit index.
sp_x_indices = paddle.cast(sp_x.indices(), "int32")
sp_x = paddle.sparse.sparse_coo_tensor(
sp_x_indices, sp_x.values(), sp_x.shape
)

sp_y = origin_y.detach().to_sparse_coo(len(y_shape))
# only support 32-bit index.
sp_y_indices = paddle.cast(sp_y.indices(), "int32")
sp_y = paddle.sparse.sparse_coo_tensor(
sp_y_indices, sp_y.values(), sp_y.shape
)

sp_x.stop_gradient = False
sp_y.stop_gradient = False
@@ -207,87 +222,33 @@ def check_result(self, x_shape, y_shape):
np.testing.assert_allclose(
sp_out.to_dense().numpy(), dense_out.numpy(), rtol=1e-05
)
if get_cuda_version() >= 11000:
dense_out.backward()
sp_out.backward()
np.testing.assert_allclose(
sp_x.grad.to_dense().numpy(),
dense_x.grad.numpy(),
rtol=1e-05,
)
np.testing.assert_allclose(
sp_y.grad.to_dense().numpy(), dense_y.grad.numpy(), rtol=1e-05
)

@unittest.skipIf(
not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000,
"only support cuda>=11.0",
)
def test_matmul_2d(self):
self.check_result([16, 12], [12, 10])

@unittest.skipIf(
not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000,
"only support cuda>=11.0",
)
def test_matmul_3d(self):
self.check_result([2, 16, 12], [2, 12, 10])


class TestMatmulCOO(unittest.TestCase):
# x: coo sparse, y: coo sparse, out: coo sparse
def check_result(self, x_shape, y_shape):
mask = paddle.randint(0, 2, x_shape)
origin_x = paddle.rand(x_shape) * mask
origin_y = paddle.rand(y_shape)

dense_x = origin_x.detach()
dense_x.stop_gradient = False
dense_y = origin_y.detach()
dense_y.stop_gradient = False
dense_out = paddle.matmul(dense_x, dense_y)

sp_x = origin_x.detach().to_sparse_coo(len(x_shape))

# only support 32-bit index.
sp_x_indices = paddle.cast(sp_x.indices(), "int32")
sp_x = paddle.sparse.sparse_coo_tensor(
sp_x_indices, sp_x.values(), sp_x.shape
)

sp_y = origin_y.detach().to_sparse_coo(len(y_shape))
# only support 32-bit index.
sp_y_indices = paddle.cast(sp_y.indices(), "int32")
sp_y = paddle.sparse.sparse_coo_tensor(
sp_y_indices, sp_y.values(), sp_y.shape
dense_out.backward()
sp_out.backward()
np.testing.assert_allclose(
sp_x.grad.to_dense().numpy(),
dense_x.grad.numpy(),
rtol=1e-05,
)

sp_x.stop_gradient = False
sp_y.stop_gradient = False

sp_out = paddle.sparse.matmul(sp_x, sp_y)
np.testing.assert_allclose(
sp_out.to_dense().numpy(), dense_out.numpy(), rtol=1e-05
sp_y.grad.to_dense().numpy(), dense_y.grad.numpy(), rtol=1e-05
)

if get_cuda_version() >= 11000:
dense_out.backward()
sp_out.backward()
np.testing.assert_allclose(
sp_x.grad.to_dense().numpy(),
dense_x.grad.numpy(),
rtol=1e-05,
)
np.testing.assert_allclose(
sp_y.grad.to_dense().numpy(), dense_y.grad.numpy(), rtol=1e-05
)

@unittest.skipIf(
not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000,
"only support cuda>=11.0",
)
def test_matmul_2d(self):
self.check_result([16, 12], [12, 10])
self.check_result([16, 12], [12, 10], 'csr')
self.check_result([16, 12], [12, 10], 'coo')

@unittest.skipIf(
not paddle.is_compiled_with_cuda() or get_cuda_version() < 11080,
"only support cuda>=11.8",
)
def test_matmul_3d(self):
self.check_result([2, 16, 12], [2, 12, 10], 'csr')
self.check_result([2, 16, 12], [2, 12, 10], 'coo')


if __name__ == "__main__":
