From bc764d388f213a631db934029af6463f19cacd5e Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 7 Jul 2024 14:54:40 -0500 Subject: [PATCH] feat(sys): CUDA 12 + cuDNN 8 builds, ref #235 --- ort-sys/build.rs | 16 ++++++++++++++-- ort-sys/dist.txt | 5 +++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/ort-sys/build.rs b/ort-sys/build.rs index 465ed2a..a7bbcf0 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -328,6 +328,18 @@ fn prepare_libort_dir() -> (PathBuf, bool) { feature_set.push("train"); } if cfg!(any(feature = "cuda", feature = "tensorrt")) { + // pytorch's CUDA docker images set `NV_CUDNN_VERSION` + let cu12_tag = match env::var("NV_CUDNN_VERSION").or_else(|_| env::var("ORT_CUDNN_VERSION")).as_deref() { + Ok(v) => { + if v.starts_with("8") { + "cu12+cudnn8" + } else { + "cu12" + } + } + Err(_) => "cu12" + }; + match env::var("ORT_DFBIN_FORCE_CUDA_VERSION").as_deref() { Ok("11") => feature_set.push("cu11"), Ok("12") => feature_set.push("cu12"), @@ -340,7 +352,7 @@ fn prepare_libort_dir() -> (PathBuf, bool) { let release_section = version_line.split(", ").nth(1).unwrap(); let version_number = release_section.split(' ').nth(1).unwrap(); if version_number.starts_with("12") { - feature_set.push("cu12"); + feature_set.push(cu12_tag); } else { feature_set.push("cu11"); } @@ -351,7 +363,7 @@ fn prepare_libort_dir() -> (PathBuf, bool) { if !success { println!("cargo:warning=nvcc call did not succeed. falling back to CUDA 12"); // fallback to CUDA 12. - feature_set.push("cu12"); + feature_set.push(cu12_tag); } } } diff --git a/ort-sys/dist.txt b/ort-sys/dist.txt index 98e3f3f..9910a5e 100644 --- a/ort-sys/dist.txt +++ b/ort-sys/dist.txt @@ -27,3 +27,8 @@ train aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/mso train x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-x86_64-apple-darwin.tgz 898EC9E3F852843ECDB618CF8E317F4C92BDEB33FC773038960857BCB37CB347 none wasm32-unknown-unknown https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-pkort_static-v1.18.1-wasm32-unknown-unknown.tgz D1BF756F02A53C3BC254E3C2048BE617082905A89182A6B1BD18C229920228EF + +train,cu12+cudnn8 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu12+cudnn8-v1.18.1-x86_64-pc-windows-msvc.tgz 52F02DBF276409DC49533373DE89B17FDE0CCB31F9974CCF31F250DC51258971 +train,cu12+cudnn8 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu12+cudnn8-v1.18.1-x86_64-unknown-linux-gnu.tgz EE0580CA961CE512ECF7C1087FB081E74C780A494EAC95596CEF1089AB573242 +cu12+cudnn8 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu12+cudnn8-v1.18.1-x86_64-unknown-linux-gnu.tgz F8D72E825F744A7A7BF2036591CBE6D1F30352DBE108BEEAD1745BD571566819 +cu12+cudnn8 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu12+cudnn8-v1.18.1-x86_64-pc-windows-msvc.tgz D41121A6489B52EB7AF9614D7924AE984F4F10BF49F15E0C4FC2655649A978ED