Skip to content

Commit

Permalink
patch changes for cudnn
Browse files Browse the repository at this point in the history
  • Loading branch information
jackalcooper committed Jun 10, 2021
1 parent 325160b commit bf7d797
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
19 changes: 16 additions & 3 deletions cmake/third_party/FindCUDNN.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ include(FindPackageHandleStandardArgs)
set(CUDNN_ROOT_DIR "" CACHE PATH "Folder contains NVIDIA cuDNN")

option(CUDNN_STATIC "Look for static cuDNN" ON)
if(BUILD_SHARED_LIBS)
set(CUDNN_STATIC OFF)
endif()
if (CUDNN_STATIC)
set(__cudnn_libname "libcudnn_static.a")
else()
Expand All @@ -24,6 +27,7 @@ find_path(CUDNN_INCLUDE_DIR cudnn.h
HINTS ${CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES cuda/include include)

unset(CUDNN_LIBRARY CACHE)
find_library(CUDNN_LIBRARY ${__cudnn_libname}
HINTS ${CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64)
Expand Down Expand Up @@ -58,8 +62,17 @@ if(CUDNN_FOUND)
endif()

set(CUDNN_INCLUDE_DIRS ${CUDNN_INCLUDE_DIR})
set(CUDNN_LIBRARIES ${CUDNN_LIBRARY})
message(STATUS "Found cuDNN: v${CUDNN_VERSION} (include: ${CUDNN_INCLUDE_DIR}, library: ${CUDNN_LIBRARY})")

if(NOT CUDNN_STATIC AND CUDNN_VERSION_MAJOR GREATER_EQUAL 8)
# skipping: libcudnn_adv_infer.so libcudnn_adv_train.so
set(CUDNN_DYNAMIC_NAMES libcudnn_cnn_infer.so libcudnn_cnn_train.so libcudnn_ops_infer.so libcudnn_ops_train.so)
get_filename_component(CUDNN_LIBRARY_DIRECTORY ${CUDNN_LIBRARY} DIRECTORY)
foreach(CUDNN_DYNAMIC_NAME ${CUDNN_DYNAMIC_NAMES})
list(APPEND CUDNN_LIBRARIES ${CUDNN_LIBRARY_DIRECTORY}/${CUDNN_DYNAMIC_NAME})
endforeach()
else()
set(CUDNN_LIBRARIES ${CUDNN_LIBRARY})
endif()
message(STATUS "Found cuDNN: v${CUDNN_VERSION} (include: ${CUDNN_INCLUDE_DIR}, library: ${CUDNN_LIBRARIES})")
mark_as_advanced(CUDNN_ROOT_DIR CUDNN_LIBRARY CUDNN_INCLUDE_DIR)
endif()

20 changes: 17 additions & 3 deletions docker/package/manylinux/build_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def build_img(
cuda_version_img = cuda_version
if cuda_version == "11.2":
cuda_version_img = "11.2.2"
if cuda_version == "11.1":
cuda_version_img = "11.1.1"
if cuda_version == "11.0":
cuda_version_img = "11.0.3"
from_img = f"nvidia/cuda:{cuda_version_img}-cudnn{cudnn_version}-devel-centos7"
tuna_build_arg = ""
if use_tuna:
Expand Down Expand Up @@ -374,6 +378,14 @@ def build():
img_prefix = f"oneflow-manylinux2014-cuda{cuda_version}"
user = getpass.getuser()
versioned_img_tag = f"{img_prefix}:0.1"
if cuda_version in ["11.0", "11.1"]:
versioned_img_tag = f"{img_prefix}:0.2"
enforced_oneflow_cmake_args = ""
if float(cuda_version) >= 11:
assert (
"CUDNN_STATIC" not in extra_oneflow_cmake_args
), "CUDNN_STATIC will be set to OFF if cuda_version > 11"
enforced_oneflow_cmake_args += " -DCUDNN_STATIC=OFF"
user_img_tag = f"{img_prefix}:{user}"
extra_docker_args = args.extra_docker_args
if "--name" not in extra_docker_args:
Expand All @@ -384,7 +396,9 @@ def build():
img_tag = args.custom_img_tag
skip_img = True
elif skip_img:
assert is_img_existing(versioned_img_tag)
assert is_img_existing(
versioned_img_tag
), f"img not found: {versioned_img_tag}"
img_tag = versioned_img_tag
else:
img_tag = user_img_tag
Expand Down Expand Up @@ -437,7 +451,7 @@ def build():
img_tag,
args.oneflow_src_dir,
cache_dir,
extra_oneflow_cmake_args,
extra_oneflow_cmake_args + enforced_oneflow_cmake_args,
extra_docker_args,
bash_args,
bash_wrap,
Expand All @@ -455,7 +469,7 @@ def build():
img_tag,
args.oneflow_src_dir,
cache_dir,
extra_oneflow_cmake_args,
extra_oneflow_cmake_args + enforced_oneflow_cmake_args,
extra_docker_args,
python_version,
args.skip_wheel,
Expand Down

0 comments on commit bf7d797

Please sign in to comment.