diff --git a/setup.py b/setup.py
index 9b3833b..d3299eb 100644
--- a/setup.py
+++ b/setup.py
@@ -4,13 +4,24 @@
 import torch
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
-if not torch.cuda.is_available():
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
+if os.environ.get("TORCH_CUDA_ARCH_LIST"):
+    # Let the PyTorch builder choose the arch(es) to target.
+    device_capability = ""
+else:
+    device_capability = torch.cuda.get_device_capability()  # NOTE(review): raises when CUDA is unavailable; export TORCH_CUDA_ARCH_LIST to cross-compile.
+    device_capability = f"{device_capability[0]}{device_capability[1]}"
 
 cwd = Path(os.path.dirname(os.path.abspath(__file__)))
-_dc = torch.cuda.get_device_capability()
-_dc = f"{_dc[0]}{_dc[1]}"
+
+nvcc_flags = [
+    "-std=c++17",  # NOTE: CUTLASS requires c++17
+]
+
+if device_capability:
+    nvcc_flags.extend([
+        f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}",
+        f"-DGROUPED_GEMM_DEVICE_CAPABILITY={device_capability}",
+    ])
 
 ext_modules = [
     CUDAExtension(
@@ -24,12 +35,7 @@
             "cxx": [
                 "-fopenmp", "-fPIC", "-Wno-strict-aliasing"
             ],
-            "nvcc": [
-                f"--generate-code=arch=compute_{_dc},code=sm_{_dc}",
-                f"-DGROUPED_GEMM_DEVICE_CAPABILITY={_dc}",
-                # NOTE: CUTLASS requires c++17.
-                "-std=c++17",
-            ],
+            "nvcc": nvcc_flags,
         }
     )
 ]
@@ -44,7 +50,7 @@
 
 setup(
     name="grouped_gemm",
-    version="0.0.1",
+    version="0.1.1",
     author="Trevor Gale",
     author_email="tgale@stanford.edu",
     description="Grouped GEMM",