Add W2A16 kernel to TritonBench
Summary:
1. Added an int2-to-int8 packing Triton kernel (see the packing sketch below).
2. Added a w2a16 GEMM built on the CUTLASS mixed-precision GEMM backbone.
3. Added the operator to TritonBench, with comparisons against the w4a16 kernels from Marlin, Machete, and tinygemm.

Reviewed By: chenyang78

Differential Revision: D68024831

fbshipit-source-id: 2caf33b9bec7d14db582176f358677566e270127
sijiac authored and facebook-github-bot committed Jan 14, 2025
1 parent d83d5f0 commit a1193d9
Showing 8 changed files with 776 additions and 0 deletions.
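
Note: the Triton packing kernel and the quantize helpers imported by main.py below (quantize_bf16_to_int2 / dequantize_int2_to_bf16) are not among the diffs shown here. Purely as an illustration of the 2-bit packing idea, a pure-PyTorch sketch with hypothetical helper names is given; four signed int2 codes are stored per int8 byte, and the within-byte bit order assumed here may differ from what the commit's Triton kernel and CUTLASS expect:

import torch

def pack_int2_to_int8(w_int2: torch.Tensor) -> torch.Tensor:
    # w_int2: int8 tensor with values in [-2, 1], shape [N, K], K % 4 == 0.
    n, k = w_int2.shape
    w = (w_int2.view(torch.uint8) & 0b11).reshape(n, k // 4, 4)
    # Four 2-bit codes per byte, lowest bits first (bit order is an assumption).
    packed = w[..., 0] | (w[..., 1] << 2) | (w[..., 2] << 4) | (w[..., 3] << 6)
    return packed.view(torch.int8)  # [N, K // 4]

def unpack_int8_to_int2(packed: torch.Tensor) -> torch.Tensor:
    u = packed.view(torch.uint8)
    codes = torch.stack([(u >> s) & 0b11 for s in (0, 2, 4, 6)], dim=-1)
    vals = codes.to(torch.int16)
    # Sign-extend two's-complement 2-bit codes: 2 -> -2, 3 -> -1.
    return torch.where(vals >= 2, vals - 4, vals).reshape(packed.shape[0], -1)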
1 change: 1 addition & 0 deletions tritonbench/operators/mixed_gemm/__init__.py
@@ -0,0 +1 @@
from .operator import Operator
Empty file.
102 changes: 102 additions & 0 deletions tritonbench/operators/mixed_gemm/kernels/helper.h
@@ -0,0 +1,102 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include <iostream>
#include "cuda_runtime.h"

/**
* Panic wrapper for unwinding CUTLASS errors
*/
#define CUTLASS_CHECK(status)                                              \
  {                                                                        \
    cutlass::Status error = status;                                        \
    if (error != cutlass::Status::kSuccess) {                              \
      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error)  \
                << " at: " << __LINE__ << std::endl;                       \
      exit(EXIT_FAILURE);                                                  \
    }                                                                      \
  }

/**
 * Panic wrapper for unwinding CUDA runtime errors
 */
#define CUDA_CHECK(status)                                                 \
  {                                                                        \
    cudaError_t error = status;                                            \
    if (error != cudaSuccess) {                                            \
      std::cerr << "Got bad cuda status: " << cudaGetErrorString(error)    \
                << " at line: " << __LINE__ << std::endl;                  \
      exit(EXIT_FAILURE);                                                  \
    }                                                                      \
  }
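
// Illustrative usage of the wrappers above (added note, not in the original
// header): wrap each CUTLASS / CUDA call so a failure reports the line, e.g.
//   CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
//   CUDA_CHECK(cudaDeviceSynchronize());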

/**
* GPU timer for recording the elapsed time across kernel(s) launched in GPU
* stream
*/
struct GpuTimer {
  cudaStream_t _stream_id;
  cudaEvent_t _start;
  cudaEvent_t _stop;

  /// Constructor
  GpuTimer() : _stream_id(0) {
    CUDA_CHECK(cudaEventCreate(&_start));
    CUDA_CHECK(cudaEventCreate(&_stop));
  }

  /// Destructor
  ~GpuTimer() {
    CUDA_CHECK(cudaEventDestroy(_start));
    CUDA_CHECK(cudaEventDestroy(_stop));
  }

  /// Start the timer for a given stream (defaults to the default stream)
  void start(cudaStream_t stream_id = 0) {
    _stream_id = stream_id;
    CUDA_CHECK(cudaEventRecord(_start, _stream_id));
  }

  /// Stop the timer
  void stop() {
    CUDA_CHECK(cudaEventRecord(_stop, _stream_id));
  }

  /// Return the elapsed time (in milliseconds)
  float elapsed_millis() {
    float elapsed = 0.0;
    CUDA_CHECK(cudaEventSynchronize(_stop));
    CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop));
    return elapsed;
  }
};
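
// Illustrative usage (added note): bracket kernel launches on a stream, e.g.
//   GpuTimer timer;
//   timer.start(stream);
//   some_kernel<<<grid, block, 0, stream>>>(...);
//   timer.stop();
//   float ms = timer.elapsed_millis();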
57 changes: 57 additions & 0 deletions tritonbench/operators/mixed_gemm/kernels/main.py
@@ -0,0 +1,57 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import torch

# Load the custom op library
torch.ops.load_library(
    "//pytorch/tritonbench/tritonbench/operators/mixed_gemm/kernels:w2a16_gemm_lib"
)

from .quantize import dequantize_int2_to_bf16, quantize_bf16_to_int2


def main():
    M = 1024
    N = 8192
    K = 8192

    dtype = torch.bfloat16

    X = torch.randn([M, K], dtype=dtype, device="cuda")
    W = torch.tensor(
        [[-2.0, -1.0, 0, 1.0] * (K // 4) for _ in range(N)],
        dtype=dtype,
        device="cuda",
    )

    WQ = quantize_bf16_to_int2(W)

    out = torch.ops.mixed_gemm.w2a16_gemm(X, WQ)

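    # (Added note) transpose().contiguous().transpose() keeps WQ's shape but
    # re-lays the storage out column-major; presumably the layout that
    # dequantize_int2_to_bf16 expects (the quantize helpers are not shown in
    # this diff).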
    WQ = WQ.transpose(0, 1).contiguous().transpose(0, 1)

    print(X.shape, X.dtype, X.stride())
    print(WQ.shape, WQ.dtype, WQ.stride())
    print(out.shape, out.dtype, out.stride())

    W_dequant = dequantize_int2_to_bf16(WQ)

    out_ref = torch.matmul(X, W)
    out_ref_dequant = torch.matmul(X, W_dequant)

    print("==== CUTLASS ====")
    print(out.shape, out.dtype)
    print(out[0])
    print(out[1])
    print("==== Reference ====")
    print(out_ref.shape, out_ref.dtype)
    print(out_ref[0])
    print(out_ref[1])
    print("==== Reference Dequant ====")
    print(out_ref_dequant.shape, out_ref_dequant.dtype)
    print(out_ref_dequant[0])
    print(out_ref_dequant[1])


if __name__ == "__main__":
    main()
14 changes: 14 additions & 0 deletions tritonbench/operators/mixed_gemm/kernels/torch_op.cpp
@@ -0,0 +1,14 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include <ATen/ATen.h>
#include <torch/library.h>

at::Tensor w2a16(const at::Tensor& X, const at::Tensor& WQ);

TORCH_LIBRARY_FRAGMENT(mixed_gemm, m) {
  m.def("w2a16_gemm(Tensor X, Tensor WQ) -> Tensor");
}

TORCH_LIBRARY_IMPL(mixed_gemm, CUDA, m) {
  m.impl("w2a16_gemm", w2a16);
}
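
The fragment above declares the op schema once and binds the implementation to the CUDA dispatch key only; there is no CPU kernel. A minimal sketch of calling the registered op (assuming the library has been loaded as in main.py above; the packed WQ layout and the shapes here are illustrative assumptions and must satisfy the kernel's alignment requirements):

import torch

x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")              # [M, K]
wq = torch.randint(-128, 128, (512, 64), dtype=torch.int8, device="cuda")   # int2 packed, [N, K // 4]
out = torch.ops.mixed_gemm.w2a16_gemm(x, wq)  # [M, N], dispatched to the CUDA impl

# No CPU implementation is registered, so this raises NotImplementedError:
# torch.ops.mixed_gemm.w2a16_gemm(x.cpu(), wq.cpu())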
181 changes: 181 additions & 0 deletions tritonbench/operators/mixed_gemm/kernels/w2a16_gemm.cu
@@ -0,0 +1,181 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/

#include <iostream>

#include <ATen/ATen.h>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"

#include "helper.h"

using namespace cute;

template <int TB_M, int TB_N, int TB_K, typename INPUT_DTYPE>
at::Tensor w2a16_kernel(
    const at::Tensor& X, // FP16/BF16
    const at::Tensor& WQ // INT2, packed in INT8
) {
  TORCH_CHECK(X.is_cuda() && X.is_contiguous());
  TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
  // X: [M, K]
  // WQ: [N, K]
  int M = X.size(0);
  int K = X.size(1);
  int N = WQ.size(0);

  auto O = at::empty({M, N}, X.options());

  /////////////////////////////////////////////////////////////////////////////
  /// GEMM kernel configurations
  /////////////////////////////////////////////////////////////////////////////
  using MmaType = INPUT_DTYPE;
  using QuantType = cutlass::int2b_t;

  // A matrix configuration
  using ElementA = MmaType;
  using LayoutA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // B matrix configuration
  using ElementB = QuantType;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
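  // (Added note) For int2b_t, sizeof_bits is 2, so AlignmentB is
  // 128 / 2 = 64 elements per 16-byte access; for fp16/bf16 activations,
  // AlignmentA is 128 / 16 = 8 elements.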

  // Layout transposes
  using LayoutA_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutA>::type;
  using LayoutB_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutB>::type;

  // C/D matrix configuration
  using ElementOut = INPUT_DTYPE;
  using LayoutOut = cutlass::layout::RowMajor;
  constexpr int AlignmentOut = 128 / cutlass::sizeof_bits<ElementOut>::value;

  // Kernel configurations
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using TileShape = Shape<Int<TB_M>, Int<TB_N>, Int<TB_K>>;
  using ClusterShape = Shape<_2, _1, _1>;
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

  // Epilogue
  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90,
          cutlass::arch::OpClassTensorOp,
          TileShape,
          ClusterShape,
          EpilogueTileType,
          ElementAccumulator,
          ElementCompute,
          ElementOut,
          typename cutlass::layout::LayoutTranspose<LayoutOut>::type,
          AlignmentOut,
          ElementOut,
          typename cutlass::layout::LayoutTranspose<LayoutOut>::type,
          AlignmentOut,
          EpilogueSchedule>::CollectiveOp;

  // MainLoop for convert-only mode
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          ElementB,
          // LayoutB,
          LayoutB_Transpose,
          AlignmentB,
          ElementA,
          // LayoutA,
          LayoutA_Transpose,
          AlignmentA,
          ElementAccumulator,
          TileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  // Kernel definition
  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int>,
      CollectiveMainloop,
      CollectiveEpilogue>;

  // Final GEMM adapter
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  // using StrideA = typename GemmKernel::StrideA;
  // using StrideB = typename GemmKernel::StrideB;
  // Stride definitions
  using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
  using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
  using StrideOut = typename GemmKernel::StrideC;

  StrideA stride_A =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
  StrideB stride_B =
      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, 1));
  StrideOut stride_Out =
      cutlass::make_cute_packed_stride(StrideOut{}, cute::make_shape(N, M, 1));

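  // (Added note) The operands are swapped relative to the math O = X * WQ^T:
  // the quantized WQ is passed in CUTLASS's A slot and X in the B slot, with
  // problem shape {N, M, K} and all layouts transposed. The kernel therefore
  // computes O^T = WQ * X^T, and writing O^T column-major is exactly the
  // row-major [M, N] output, so no extra transpose is needed.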
  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {N, M, K},
      {reinterpret_cast<ElementB*>(WQ.data_ptr()),
       stride_B,
       reinterpret_cast<ElementA*>(X.data_ptr()),
       stride_A},
      {{},
       (ElementOut*)O.data_ptr(),
       stride_Out,
       (ElementOut*)O.data_ptr(),
       stride_Out}};

  Gemm gemm;

  size_t workspace_size = Gemm::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  CUTLASS_CHECK(gemm.can_implement(arguments));
  CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
  CUTLASS_CHECK(gemm.run());

  return O;
}

template <typename INPUT_DTYPE>
at::Tensor dispatch_w2a16_kernel(const at::Tensor& X, const at::Tensor& WQ) {
  // template <int TB_M, int TB_N, int TB_K, typename INPUT_DTYPE>
  return w2a16_kernel<128, 128, 128, INPUT_DTYPE>(X, WQ);
}

at::Tensor w2a16(const at::Tensor& X, const at::Tensor& WQ) {
  if (X.dtype() == at::kHalf) {
    return dispatch_w2a16_kernel<cutlass::half_t>(X, WQ);
  } else if (X.dtype() == at::kBFloat16) {
    return dispatch_w2a16_kernel<cutlass::bfloat16_t>(X, WQ);
  } else {
    throw std::runtime_error("DType of the activation (X) is not supported");
  }
}
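
The transposed problem setup in w2a16_gemm.cu relies on the identity (WQ · Xᵀ)ᵀ = X · WQᵀ. A minimal PyTorch check of that identity (illustrative only, in full precision rather than int2):

import torch

M, N, K = 4, 8, 16
X = torch.randn(M, K)
W = torch.randn(N, K)
# (W @ X.T).T is what the kernel materializes; X @ W.T is the reference.
assert torch.allclose((W @ X.T).T, X @ W.T, atol=1e-5)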