Skip to content

Commit

Permalink
fix windows build failure
Browse files Browse the repository at this point in the history
  • Loading branch information
chengzeyi committed Jan 31, 2024
1 parent b0b6945 commit 384b1d4
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions src/sfast/csrc/operators/cutlass/cutlass_dual_linear_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,25 @@ torch::Tensor cutlass_linear_geglu(const torch::Tensor &input,
}

torch::Tensor output;

auto dispatch_bf16 = [&] {
#if TORCH_VERSION_MAJOR > 2 || \
(TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 2)
if (at::globalContext().allowBF16ReductionCuBLAS()) {
output =
CutlassDualGemmLauncher<at::BFloat16, GemmGEGLUWrapper,
cutlass::epilogue::thread::GELU_taylor_fast,
true>::launch(input, weight0, bias0, weight1,
bias1, fallback);
} else
#endif
{
output = CutlassDualGemmLauncher<at::BFloat16, GemmGEGLUWrapper,
cutlass::epilogue::thread::GELU,
false>::launch(input, weight0, bias0,
weight1, bias1, fallback);
}
};
AT_DISPATCH_SWITCH(
input.scalar_type(), "cutlass_linear_geglu",
AT_DISPATCH_CASE(
Expand All @@ -501,24 +520,7 @@ torch::Tensor cutlass_linear_geglu(const torch::Tensor &input,
bias1, fallback);
}
});
AT_DISPATCH_CASE(at::kBFloat16, [&] {
#if TORCH_VERSION_MAJOR > 2 || \
(TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 2)
if (at::globalContext().allowBF16ReductionCuBLAS()) {
output = CutlassDualGemmLauncher<
at::BFloat16, GemmGEGLUWrapper,
cutlass::epilogue::thread::GELU_taylor_fast,
true>::launch(input, weight0, bias0, weight1, bias1, fallback);
} else
#endif
{
output =
CutlassDualGemmLauncher<at::BFloat16, GemmGEGLUWrapper,
cutlass::epilogue::thread::GELU,
false>::launch(input, weight0, bias0,
weight1, bias1, fallback);
}
}));
AT_DISPATCH_CASE(at::kBFloat16, dispatch_bf16));
return output;
}

Expand Down

0 comments on commit 384b1d4

Please sign in to comment.