Skip to content

Commit

Permalink
Look for nvcc in CUDA_HOME
Browse files Browse the repository at this point in the history
  • Loading branch information
amiller27 committed Aug 26, 2024
1 parent daf9628 commit 3294b15
Showing 1 changed file with 58 additions and 20 deletions.
78 changes: 58 additions & 20 deletions bindings/torch/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,64 @@ def find_cl_path():
cpp_standard = 14

# Get CUDA version and make sure the targeted compute capability is compatible
if os.system("nvcc --version") == 0:
nvcc_out = subprocess.check_output(["nvcc", "--version"]).decode()
cuda_version = re.search(r"release (\S+),", nvcc_out)

if cuda_version:
cuda_version = parse_version(cuda_version.group(1))
print(f"Detected CUDA version {cuda_version}")
if cuda_version >= parse_version("11.0"):
cpp_standard = 17

supported_compute_capabilities = [
cc for cc in compute_capabilities if cc >= min_supported_compute_capability(cuda_version) and cc <= max_supported_compute_capability(cuda_version)
]

if not supported_compute_capabilities:
supported_compute_capabilities = [max_supported_compute_capability(cuda_version)]

if supported_compute_capabilities != compute_capabilities:
print(f"WARNING: Compute capabilities {compute_capabilities} are not all supported by the installed CUDA version {cuda_version}. Targeting {supported_compute_capabilities} instead.")
compute_capabilities = supported_compute_capabilities
def _maybe_find_nvcc():
# Try PATH first
maybe_nvcc = shutil.which("nvcc")

if maybe_nvcc is not None:
return maybe_nvcc

# Then try CUDA_HOME from torch (cpp_extension.CUDA_HOME is undocumented, which is why we only use
# it as a fallback)
try:
from torch.utils.cpp_extension import CUDA_HOME
except ImportError:
return None

if not CUDA_HOME:
return None

return os.path.join(CUDA_HOME, "bin", "nvcc")

def _maybe_nvcc_version():
maybe_nvcc = _maybe_find_nvcc()

if maybe_nvcc is None:
return None

nvcc_version_result = subprocess.run(
[maybe_nvcc, "--version"],
text=True,
check=False,
stdout=subprocess.PIPE,
)

if nvcc_version_result.returncode != 0:
return None

cuda_version = re.search(r"release (\S+),", nvcc_version_result.stdout)

if not cuda_version:
return None

return parse_version(cuda_version.group(1))

cuda_version = _maybe_nvcc_version()
if cuda_version is not None:
print(f"Detected CUDA version {cuda_version}")
if cuda_version >= parse_version("11.0"):
cpp_standard = 17

supported_compute_capabilities = [
cc for cc in compute_capabilities if cc >= min_supported_compute_capability(cuda_version) and cc <= max_supported_compute_capability(cuda_version)
]

if not supported_compute_capabilities:
supported_compute_capabilities = [max_supported_compute_capability(cuda_version)]

if supported_compute_capabilities != compute_capabilities:
print(f"WARNING: Compute capabilities {compute_capabilities} are not all supported by the installed CUDA version {cuda_version}. Targeting {supported_compute_capabilities} instead.")
compute_capabilities = supported_compute_capabilities

min_compute_capability = min(compute_capabilities)

Expand Down

0 comments on commit 3294b15

Please sign in to comment.