Skip to content

Commit

Permalink
Merge pull request #369 from vzhurba01/backport-363-to-11.8.x
Browse files Browse the repository at this point in the history
Backport 363 to 11.8.x
  • Loading branch information
vzhurba01 authored Jan 9, 2025
2 parents ca9e641 + 9f13a17 commit a20f0f4
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 17 deletions.
66 changes: 60 additions & 6 deletions cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
# this software and related documentation outside the terms of the EULA
# is strictly prohibited.
{{if 'Windows' == platform.system()}}
import win32api
import os
import site
import struct
import win32api
from pywintypes import error
{{else}}
cimport cuda.bindings._lib.dlfcn as dlfcn
Expand Down Expand Up @@ -40,18 +42,70 @@ cdef int cuPythonInit() except -1 nogil:

# Load library
{{if 'Windows' == platform.system()}}
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
with gil:
# First check if the DLL has been loaded by 3rd parties
try:
handle = win32api.LoadLibraryEx("nvrtc64_112_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
handle = win32api.GetModuleHandle("nvrtc64_112_0.dll")
except:
try:
handle = win32api.LoadLibraryEx("nvrtc64_111_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
handle = win32api.GetModuleHandle("nvrtc64_111_0.dll")
except:
try:
handle = win32api.GetModuleHandle("nvrtc64_110_0.dll")
except:
handle = None

# Else try default search
if not handle:
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
try:
handle = win32api.LoadLibraryEx("nvrtc64_112_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
except:
try:
handle = win32api.LoadLibraryEx("nvrtc64_111_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
except:
try:
handle = win32api.LoadLibraryEx("nvrtc64_110_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
except:
pass

# Final check if DLLs can be found within pip installations
if not handle:
site_packages = [site.getusersitepackages()] + site.getsitepackages()
for sp in site_packages:
mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
if not os.path.isdir(mod_path):
continue
os.add_dll_directory(mod_path)
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
try:
handle = win32api.LoadLibraryEx(
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
os.path.join(mod_path, "nvrtc64_112_0.dll"),
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)

# Note: nvrtc64_112_0.dll calls into nvrtc-builtins64_*.dll which is
# located in the same mod_path.
# Update PATH environ so that the two dlls can find each other
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
except:
try:
handle = win32api.LoadLibraryEx("nvrtc64_110_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
handle = win32api.LoadLibraryEx(
os.path.join(mod_path, "nvrtc64_111_0.dll"),
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
except:
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_112_0.dll, or nvrtc64_111_0.dll, or nvrtc64_110_0.dll')
try:
handle = win32api.LoadLibraryEx(
os.path.join(mod_path, "nvrtc64_110_0.dll"),
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
except:
pass

if not handle:
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_112_0.dll, or nvrtc64_111_0.dll, or nvrtc64_110_0.dll')
{{else}}
handle = NULL
if handle == NULL:
Expand Down
5 changes: 5 additions & 0 deletions cuda_bindings/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ dependencies = [
"pywin32; sys_platform == 'win32'",
]

[project.optional-dependencies]
all = [
"nvidia-cuda-nvrtc-cu11"
]

[project.urls]
Repository = "https://github.com/NVIDIA/cuda-python"
Documentation = "https://nvidia.github.io/cuda-python/"
Expand Down
51 changes: 40 additions & 11 deletions cuda_bindings/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,15 @@
from pyclibrary import CParser
from setuptools import find_packages, setup
from setuptools.extension import Extension
from setuptools.command.bdist_wheel import bdist_wheel
from setuptools.command.build_ext import build_ext
import versioneer


# ----------------------------------------------------------------------
# Fetch configuration options

# Locate the CUDA toolkit: prefer CUDA_HOME and fall back to CUDA_PATH.
# Using ``or`` (rather than a ``get`` default) also falls back when CUDA_HOME
# is present in the environment but set to an empty string.
CUDA_HOME = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if not CUDA_HOME:
    raise RuntimeError('Environment variable CUDA_HOME or CUDA_PATH is not set')

Expand Down Expand Up @@ -236,20 +235,50 @@ def do_cythonize(extensions):
extensions += prep_extensions(sources)

# ---------------------------------------------------------------------
# Custom build_ext command
# Files are built in two steps:
# 1) Cythonized (in the do_cythonize() command)
# 2) Compiled to .o files as part of build_ext
# This class is solely for passing the value of nthreads to build_ext
# Custom cmdclass extensions

# Module-level flag; stays False unless a ``bdist_wheel`` command is run.
building_wheel = False


class WheelsBuildExtensions(bdist_wheel):
    """``bdist_wheel`` command that records that a wheel build is in progress."""

    def run(self):
        """Set the module-level ``building_wheel`` flag, then run ``bdist_wheel``."""
        # The flag is raised before delegating so that anything executed as
        # part of the wheel build can observe it.
        global building_wheel
        building_wheel = True
        super().run()


class ParallelBuildExtensions(build_ext):
    """``build_ext`` command with parallel compilation and wheel link flags.

    The module-level ``nthreads`` value is forwarded to ``build_ext`` as its
    ``parallel`` option, and extra linker flags are appended to each
    extension when a wheel is being built on Linux.
    """

    def initialize_options(self):
        """Initialize the base options, then enable parallel builds.

        ``self.parallel`` is only overridden when ``nthreads`` is positive.
        """
        super().initialize_options()
        if nthreads > 0:
            self.parallel = nthreads

    def build_extension(self, ext):
        """Build ``ext``, adding wheel-specific linker flags on Linux.

        Flags are appended only when the module-level ``building_wheel`` flag
        is set and ``sys.platform`` is ``"linux"``; otherwise ``ext`` is built
        with its link arguments unchanged.
        """
        extra_linker_flags = []
        if building_wheel and sys.platform == "linux":
            # Strip binaries to remove debug symbols
            extra_linker_flags.append("-Wl,--strip-all")

            # Allow extensions to discover libraries at runtime
            # relative to their wheel's installation.
            if ext.name == "cuda.bindings._bindings.cynvrtc":
                extra_linker_flags.append(
                    "-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib"
                )

        ext.extra_link_args += extra_linker_flags
        super().build_extension(ext)


# Register both custom commands. The mapping is built exactly once so the
# ``bdist_wheel`` hook is not lost to a later re-assignment.
cmdclass = {
    "bdist_wheel": WheelsBuildExtensions,
    "build_ext": ParallelBuildExtensions,
}
# Hand the mapping to versioneer and use the cmdclass it returns.
cmdclass = versioneer.get_cmdclass(cmdclass)
Expand Down

0 comments on commit a20f0f4

Please sign in to comment.