
Commit

fix new
yangguohao committed Dec 21, 2023
1 parent 767ede1 commit 9fc182f
Showing 5 changed files with 276 additions and 209 deletions.
93 changes: 19 additions & 74 deletions paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
@@ -90,6 +90,8 @@ inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x,

int64_t batch_nnz = x.nnz() / batch_size;
cudaDataType_t gpu_type = GetGpuDataType<T>();
cusparseIndexType_t index_type =
    std::is_same<IntT, int32_t>::value ? CUSPARSE_INDEX_32I : CUSPARSE_INDEX_64I;
dev_ctx.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseCreateCsr(descriptor,
M,
@@ -98,8 +100,8 @@ inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x,
const_cast<IntT*>(crows_data),
const_cast<IntT*>(cols_data),
const_cast<T*>(values_data),
CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_32I,
index_type,
index_type,
CUSPARSE_INDEX_BASE_ZERO,
gpu_type);
});
@@ -117,36 +119,6 @@ inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x,
}
}

template <typename T>
inline void CreateOutCsrDescriptor(const phi::SparseCsrTensor& x,
const phi::GPUContext& dev_ctx,
cusparseSpMatDescr_t* descriptor) {
std::vector<int64_t> xdim_vec = common::vectorize(x.dims());
auto x_ndims = xdim_vec.size();
PADDLE_ENFORCE_GE(
x_ndims,
2,
phi::errors::InvalidArgument("the dim size of SparseCsrTensor must be "
"greater than or eaqual to 2."));
int64_t M = xdim_vec[x_ndims - 2];
int64_t N = xdim_vec[x_ndims - 1];

cudaDataType_t gpu_type = GetGpuDataType<T>();
dev_ctx.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseCreateCsr(descriptor,
M,
N,
0,
nullptr,
nullptr,
nullptr,
CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO,
gpu_type);
});
}

template <typename T, typename IntT>
inline void CreateCooDescriptor(const phi::SparseCooTensor& x,
const phi::GPUContext& dev_ctx,
@@ -201,30 +173,6 @@ inline void CreateCooDescriptor(const phi::SparseCooTensor& x,
}
}

template <typename T>
class CuSparseOutSpMatDescriptor {
public:
explicit CuSparseOutSpMatDescriptor(const phi::SparseCsrTensor& x,
const phi::GPUContext& dev_ctx)
: dev_ctx_(dev_ctx) {
CreateOutCsrDescriptor<T>(x, dev_ctx_, &descriptor_);
VLOG(6) << "Create csr cusparseSpMatDescr_t " << &descriptor_;
}

~CuSparseOutSpMatDescriptor() {
dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseDestroySpMat(descriptor_);
});
VLOG(6) << "Destroy cusparseSpMatDescr_t " << &descriptor_;
}

const cusparseSpMatDescr_t& descriptor() const { return descriptor_; }

private:
const phi::GPUContext& dev_ctx_;
cusparseSpMatDescr_t descriptor_;
};

template <typename T>
class CuSparseSpMatDescriptor {
public:
@@ -505,9 +453,19 @@ void SparseBlas<phi::GPUContext>::SPMM(bool transa,
const TensorType& mat_b,
T beta,
TensorType* mat_out) const {
auto dims = mat_out->dims();
DenseTensor* mat_out_crows = mat_out->mutable_crows();
MetaTensor meta_out_crows(mat_out_crows);
meta_out_crows.set_dtype(mat_a.crows().dtype());
meta_out_crows.set_dims(common::make_ddim({dims[dims.size() - 2] + 1}));
int* out_crows = dev_ctx_.template Alloc<int>(mat_out_crows);
DenseTensor* mat_out_cols = mat_out->mutable_cols();
MetaTensor meta_out_cols(mat_out_cols);
meta_out_cols.set_dtype(mat_a.cols().dtype());

auto a_descriptor = CuSparseSpMatDescriptor<T>(mat_a, dev_ctx_);
auto b_descriptor = CuSparseSpMatDescriptor<T>(mat_b, dev_ctx_);
auto out_descriptor = CuSparseOutSpMatDescriptor<T>(*mat_out, dev_ctx_);
auto out_descriptor = CuSparseSpMatDescriptor<T>(*mat_out, dev_ctx_);
auto spgemm_descriptor = CuSparseSpGEMMDescriptor<T>(dev_ctx_);

cudaDataType_t gpu_type = GetGpuDataType<T>();
Expand Down Expand Up @@ -593,29 +551,16 @@ void SparseBlas<phi::GPUContext>::SPMM(bool transa,
});

DenseTensor* mat_out_values = mat_out->mutable_values();
DenseTensor* mat_out_crows = mat_out->mutable_crows();
DenseTensor* mat_out_cols = mat_out->mutable_cols();
MetaTensor meta_out_values(mat_out_values);
MetaTensor meta_out_crows(mat_out_crows);
MetaTensor meta_out_cols(mat_out_cols);
meta_out_crows.set_dtype(mat_a.crows().dtype());
meta_out_cols.set_dtype(mat_a.cols().dtype());
meta_out_crows.set_dims(common::make_ddim({C_num_rows1 + 1}));
meta_out_cols.set_dims(common::make_ddim({C_nnz1}));
meta_out_values.set_dtype(mat_a.values().dtype());
meta_out_values.set_dims(common::make_ddim({C_nnz1}));
dev_ctx_.template Alloc<T>(mat_out_values);
dev_ctx_.template Alloc<int>(mat_out_cols);
dev_ctx_.template Alloc<int>(mat_out_crows);
T* out_values = dev_ctx_.template Alloc<T>(mat_out_values);
int* out_cols = dev_ctx_.template Alloc<int>(mat_out_cols);

T* out_values = mat_out_values->data<T>();
int* out_crows = mat_out_crows->data<int>();
int* out_cols = mat_out_cols->data<int>();
dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseCsrSetPointers(out_descriptor.descriptor(),
const_cast<int*>(out_crows),
const_cast<int*>(out_cols),
const_cast<T*>(out_values));
phi::dynload::cusparseCsrSetPointers(
out_descriptor.descriptor(), out_crows, out_cols, out_values);
});

dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
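The reworked SPMM path above follows the standard cuSPARSE SpGEMM calling sequence: the output descriptor starts with nnz = 0 and null pointers, cusparseSpGEMM_compute determines the result size, and only then are the output crows/cols/values buffers allocated and bound with cusparseCsrSetPointers before cusparseSpGEMM_copy writes the product. A minimal standalone sketch of that sequence, outside the Paddle wrappers (the handle and the A/B/C descriptors are assumed to exist already, with matC created using CUDA_R_32F values and CUSPARSE_INDEX_32I indices; all error checking is omitted):

#include <cuda_runtime.h>
#include <cusparse.h>

// Sketch only: matA/matB are fully populated CSR descriptors, matC was created
// with nnz = 0 and null pointers (as the output descriptor is above).
void SpGemmComputeAndAllocate(cusparseHandle_t handle,
                              cusparseSpMatDescr_t matA,
                              cusparseSpMatDescr_t matB,
                              cusparseSpMatDescr_t matC) {
  float alpha = 1.0f, beta = 0.0f;
  cusparseOperation_t op = CUSPARSE_OPERATION_NON_TRANSPOSE;
  cudaDataType compute_type = CUDA_R_32F;

  cusparseSpGEMMDescr_t spgemm_desc;
  cusparseSpGEMM_createDescr(&spgemm_desc);

  // 1. work estimation: query the buffer size, allocate, then run again with it
  size_t buffer_size1 = 0;
  void* buffer1 = nullptr;
  cusparseSpGEMM_workEstimation(handle, op, op, &alpha, matA, matB, &beta, matC,
                                compute_type, CUSPARSE_SPGEMM_DEFAULT,
                                spgemm_desc, &buffer_size1, nullptr);
  cudaMalloc(&buffer1, buffer_size1);
  cusparseSpGEMM_workEstimation(handle, op, op, &alpha, matA, matB, &beta, matC,
                                compute_type, CUSPARSE_SPGEMM_DEFAULT,
                                spgemm_desc, &buffer_size1, buffer1);

  // 2. compute the product into cuSPARSE-internal storage
  size_t buffer_size2 = 0;
  void* buffer2 = nullptr;
  cusparseSpGEMM_compute(handle, op, op, &alpha, matA, matB, &beta, matC,
                         compute_type, CUSPARSE_SPGEMM_DEFAULT,
                         spgemm_desc, &buffer_size2, nullptr);
  cudaMalloc(&buffer2, buffer_size2);
  cusparseSpGEMM_compute(handle, op, op, &alpha, matA, matB, &beta, matC,
                         compute_type, CUSPARSE_SPGEMM_DEFAULT,
                         spgemm_desc, &buffer_size2, buffer2);

  // 3. only now is the output nnz known; allocate C's arrays and bind them
  int64_t c_rows = 0, c_cols = 0, c_nnz = 0;
  cusparseSpMatGetSize(matC, &c_rows, &c_cols, &c_nnz);
  int *c_crows = nullptr, *c_col_indices = nullptr;
  float* c_values = nullptr;
  cudaMalloc(&c_crows, (c_rows + 1) * sizeof(int));
  cudaMalloc(&c_col_indices, c_nnz * sizeof(int));
  cudaMalloc(&c_values, c_nnz * sizeof(float));
  cusparseCsrSetPointers(matC, c_crows, c_col_indices, c_values);

  // 4. copy the computed product from internal storage into C's arrays
  cusparseSpGEMM_copy(handle, op, op, &alpha, matA, matB, &beta, matC,
                      compute_type, CUSPARSE_SPGEMM_DEFAULT, spgemm_desc);

  cusparseSpGEMM_destroyDescr(spgemm_desc);
  cudaFree(buffer1);
  cudaFree(buffer2);
}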
72 changes: 59 additions & 13 deletions paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu
@@ -152,37 +152,73 @@ void MatmulCsrCsrGradKernel(const Context& dev_ctx,
// dx{SparseCsr} = dout{SparseCsr} * y'{SparseCsr}
if (dx) {
auto dims_numel = y.dims().size();
SparseCsrTensor tmp_y;
SparseCsrTensor transpose_y, tmp_dout, tmp_y;
if (dims_numel == 2) {
tmp_y = TransposeCsr<T, Context>(dev_ctx, y, {1, 0});
TransposeCsrKernel<T, Context>(dev_ctx, y, {1, 0}, &transpose_y);
} else {
tmp_y = TransposeCsr<T, Context>(dev_ctx, y, {0, 2, 1});
TransposeCsrKernel<T, Context>(dev_ctx, y, {0, 2, 1}, &transpose_y);
}

sparse_blas.SPMM(
false, false, static_cast<T>(1), dout, tmp_y, static_cast<T>(0), dx);
CastCsrKernel<T, Context>(
dev_ctx, dout, phi::DataType::INT32, dout.values().dtype(), &tmp_dout);
CastCsrKernel<T, Context>(
dev_ctx, transpose_y, phi::DataType::INT32, y.values().dtype(), &tmp_y);

sparse_blas.SPMM(false,
false,
static_cast<T>(1),
tmp_dout,
tmp_y,
static_cast<T>(0),
dx);
}

// dy{SparseCsr} = x'{SparseCsr} * dout{SparseCsr}
if (dy) {
auto dims_numel = x.dims().size();
SparseCsrTensor tmp_x;
SparseCsrTensor transpose_x, tmp_dout, tmp_x;
if (dims_numel == 2) {
tmp_x = TransposeCsr<T, Context>(dev_ctx, x, {1, 0});
TransposeCsrKernel<T, Context>(dev_ctx, x, {1, 0}, &transpose_x);
} else {
tmp_x = TransposeCsr<T, Context>(dev_ctx, x, {0, 2, 1});
TransposeCsrKernel<T, Context>(dev_ctx, x, {0, 2, 1}, &transpose_x);
}

sparse_blas.SPMM(
false, false, static_cast<T>(1), tmp_x, dout, static_cast<T>(0), dy);
CastCsrKernel<T, Context>(
dev_ctx, dout, phi::DataType::INT32, dout.values().dtype(), &tmp_dout);
CastCsrKernel<T, Context>(
dev_ctx, transpose_x, phi::DataType::INT32, x.values().dtype(), &tmp_x);
sparse_blas.SPMM(false,
false,
static_cast<T>(1),
tmp_x,
tmp_dout,
static_cast<T>(0),
dy);
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"backward of 'sparse.matmul' use cusparseSPGEMM, which is supported from "
"CUDA 11.3"));
"CUDA 11.0"));
#endif
}
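The two gradient branches above follow the standard matmul backward identities (standard calculus, not something introduced by this diff): for Out = X Y with scalar loss L,

\frac{\partial L}{\partial X} = \frac{\partial L}{\partial \mathrm{Out}} \, Y^{\top},
\qquad
\frac{\partial L}{\partial Y} = X^{\top} \, \frac{\partial L}{\partial \mathrm{Out}}

which is exactly the dx = dout * y' and dy = x' * dout noted in the comments; the transposes come from TransposeCsrKernel and both products run through the INT32-indexed SPMM (cusparseSpGEMM) path.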

template <typename T, typename Context>
void MatmulCooCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& y,
const SparseCooTensor& dout,
SparseCooTensor* dx,
SparseCooTensor* dy) {
// 'cusparseSpGEMM' only supports CSR now, so use COO->CSR->COO.
SparseCsrTensor x_csr = CooToCsr<T, Context>(dev_ctx, x);
SparseCsrTensor y_csr = CooToCsr<T, Context>(dev_ctx, y);
SparseCsrTensor dout_csr = CooToCsr<T, Context>(dev_ctx, dout);
SparseCsrTensor dx_csr, dy_csr;
dx_csr.set_dims(dx->dims());
dy_csr.set_dims(dy->dims());
MatmulCsrCsrGradKernel<T>(dev_ctx, x_csr, y_csr, dout_csr, &dx_csr, &dy_csr);
CsrToCooKernel<T>(dev_ctx, dx_csr, dx);
CsrToCooKernel<T>(dev_ctx, dy_csr, dy);
}

template <typename T, typename Context>
void MaskedMatmulCsrGradKernel(const Context& dev_ctx,
const DenseTensor& x,
@@ -271,3 +307,13 @@ PD_REGISTER_KERNEL(matmul_csr_csr_grad,
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}

PD_REGISTER_KERNEL(matmul_coo_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::MatmulCooCooGradKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
66 changes: 34 additions & 32 deletions paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/math_function_impl.h"
#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"

namespace phi {
namespace sparse {
@@ -201,11 +202,11 @@ void MaskedMatmulCsrKernel(const Context& dev_ctx,
#endif
}

template <typename T, typename Context, typename TensorType>
void MatmulKernelImpl(const Context& dev_ctx,
const TensorType& x,
const TensorType& y,
TensorType* out) {
template <typename T, typename Context>
void MatmulCsrCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& y,
SparseCsrTensor* out) {
#if CUDA_VERSION >= 11000
std::vector<int64_t> xdim_vec = common::vectorize(x.dims());
std::vector<int64_t> ydim_vec = common::vectorize(y.dims());
@@ -240,39 +241,40 @@ void MatmulKernelImpl(const Context& dev_ctx,
"The shape of Input(x) and Input(y) is not suitable for matmul "
"opetation, x_dim[-1] must be eaqual to y_dim[-2]."));

// InferMeta of SparseCsrTensor 'out'
std::vector<int64_t> out_dim_vec(ydim_vec);
out_dim_vec[y_ndims - 2] = xdim_vec[x_ndims - 2];
out_dim_vec[y_ndims - 1] = ydim_vec[y_ndims - 1];

out->set_dims(common::make_ddim(out_dim_vec));
SparseCsrTensor x_tmp, y_tmp;
CastCsrKernel<T, Context>(
dev_ctx, x, phi::DataType::INT32, x.values().dtype(), &x_tmp);
CastCsrKernel<T, Context>(
dev_ctx, y, phi::DataType::INT32, y.values().dtype(), &y_tmp);

auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
sparse_blas.SPMM(
false, false, static_cast<T>(1), x, y, static_cast<T>(0), out);
false, false, static_cast<T>(1), x_tmp, y_tmp, static_cast<T>(0), out);
#else
PADDLE_THROW(
phi::errors::Unimplemented("forward of 'sparse.matmul' use cusparseSpMM, "
"which is supported from CUDA 11.0"));
PADDLE_THROW(phi::errors::Unimplemented(
"forward of 'sparse.matmul' use cusparseSpGEMM, "
"which is supported from CUDA 11.0"));
#endif
}
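The shape inference in MatmulCsrCsrKernel above uses the usual (batched) matmul rule: out takes y's dims with the last two axes replaced by x's row count and y's column count, and x_dim[-1] must equal y_dim[-2]. A tiny standalone sketch of that rule (the helper name is hypothetical):

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> InferMatmulOutDims(const std::vector<int64_t>& xdim,
                                        const std::vector<int64_t>& ydim) {
  // K of x must match K of y, e.g. x:[B, M, K] and y:[B, K, N]
  assert(xdim[xdim.size() - 1] == ydim[ydim.size() - 2]);
  std::vector<int64_t> out(ydim);
  out[out.size() - 2] = xdim[xdim.size() - 2];  // M comes from x
  out[out.size() - 1] = ydim[ydim.size() - 1];  // N comes from y
  return out;  // -> [B, M, N]
}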

template <typename T, typename Context>
void MatmulCsrCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const SparseCsrTensor& y,
SparseCsrTensor* out) {
MatmulKernelImpl<T>(dev_ctx, x, y, out);
void MatmulCooCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const SparseCooTensor& y,
SparseCooTensor* out) {
SparseCsrTensor x_csr = CooToCsr<T, Context>(dev_ctx, x);
SparseCsrTensor y_csr = CooToCsr<T, Context>(dev_ctx, y);
SparseCsrTensor out_csr;
out_csr.set_dims(out->dims());
MatmulCsrCsrKernel<T, Context>(dev_ctx, x_csr, y_csr, &out_csr);
CsrToCooKernel<T>(dev_ctx, out_csr, out);
}

// template <typename T, typename Context>
// void MatmulCooCooKernel(const Context& dev_ctx,
// const SparseCooTensor& x,
// const SparseCooTensor& y,
// SparseCooTensor* out) {
// MatmulKernelImpl<T>(dev_ctx, x, y, out);
// }

} // namespace sparse
} // namespace phi

@@ -311,12 +313,12 @@ PD_REGISTER_KERNEL(matmul_csr_csr,
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}

// PD_REGISTER_KERNEL(matmul_coo_coo,
// GPU,
// ALL_LAYOUT,
// phi::sparse::MatmulCooCooKernel,
// float,
// double) {
// kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
// kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
// }
PD_REGISTER_KERNEL(matmul_coo_coo,
GPU,
ALL_LAYOUT,
phi::sparse::MatmulCooCooKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
(diffs for the remaining two changed files are not shown)
