Skip to content

Commit

Permalink
Merge pull request #7 from mvpatel2000/mvpatel2000/update-seutp
Browse files Browse the repository at this point in the history
Update setup to be more flexible in cuda builds
  • Loading branch information
tgale96 authored Jan 11, 2024
2 parents 108009a + 35034ac commit a5e5311
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,24 @@
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Decide which CUDA compute capability to target.
#
# If the user set TORCH_CUDA_ARCH_LIST, defer entirely to PyTorch's build
# machinery (an empty `device_capability` means "don't pin an arch below").
# Otherwise, target the capability of the locally visible GPU.
if os.environ.get("TORCH_CUDA_ARCH_LIST"):
    # Let the PyTorch builder choose the devices to target.
    device_capability = ""
else:
    # NOTE(review): this branch requires a visible CUDA device at build
    # time — torch.cuda.get_device_capability() raises without one.
    device_capability = torch.cuda.get_device_capability()
    # e.g. (9, 0) -> "90", the form nvcc's sm_/compute_ suffixes expect.
    device_capability = f"{device_capability[0]}{device_capability[1]}"

# Directory containing this setup.py; used to locate sources/headers.
cwd = Path(os.path.dirname(os.path.abspath(__file__)))

# Flags passed to nvcc for every CUDA translation unit.
nvcc_flags = [
    "-std=c++17",  # NOTE: CUTLASS requires c++17.
]

if device_capability:
    # Pin code generation to the detected architecture and expose the
    # capability to the C++/CUDA sources as a preprocessor define.
    nvcc_flags.extend([
        f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}",
        f"-DGROUPED_GEMM_DEVICE_CAPABILITY={device_capability}",
    ])

ext_modules = [
CUDAExtension(
Expand All @@ -24,12 +35,7 @@
"cxx": [
"-fopenmp", "-fPIC", "-Wno-strict-aliasing"
],
"nvcc": [
f"--generate-code=arch=compute_{_dc},code=sm_{_dc}",
f"-DGROUPED_GEMM_DEVICE_CAPABILITY={_dc}",
# NOTE: CUTLASS requires c++17.
"-std=c++17",
],
"nvcc": nvcc_flags,
}
)
]
Expand All @@ -44,7 +50,7 @@

setup(
name="grouped_gemm",
version="0.0.1",
version="0.1.1",
author="Trevor Gale",
author_email="[email protected]",
description="Grouped GEMM",
Expand Down

0 comments on commit a5e5311

Please sign in to comment.