Skip to content

Commit

Permalink
Using arch option in nvrtc (#675)
Browse files (browse the repository at this point in the history)
* Using arch in nvrtc

* Fixing unused warning for the cudnn field

* Fixing the CUDA_COMPUTE_CAP env var when ci-check is active
  • Loading branch information
coreylowman authored Apr 8, 2023
1 parent 64a60e2 commit ee64526
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 8 deletions.
16 changes: 11 additions & 5 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,16 @@ mod cuda {
.collect::<Vec<_>>();

#[cfg(feature = "ci-check")]
for mut kernel_path in kernel_paths.into_iter() {
kernel_path.set_extension("ptx");
{
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=ci");

for mut kernel_path in kernel_paths.into_iter() {
kernel_path.set_extension("ptx");

let mut ptx_path: std::path::PathBuf = out_dir.clone().into();
ptx_path.push(kernel_path.as_path().file_name().unwrap());
std::fs::File::create(ptx_path).unwrap();
let mut ptx_path: std::path::PathBuf = out_dir.clone().into();
ptx_path.push(kernel_path.as_path().file_name().unwrap());
std::fs::File::create(ptx_path).unwrap();
}
}

#[cfg(not(feature = "ci-check"))]
Expand All @@ -76,6 +80,8 @@ mod cuda {
lines.next().unwrap().replace('.', "")
};

println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");

kernel_paths
.iter()
.for_each(|p| println!("cargo:rerun-if-changed={}", p.display()));
Expand Down
1 change: 1 addition & 0 deletions src/tensor/cuda/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub struct Cuda {
pub(crate) dev: Arc<CudaDevice>,
pub(crate) blas: Arc<CudaBlas>,
#[cfg(feature = "cudnn")]
#[allow(unused)]
pub(crate) cudnn: Arc<cudarc::cudnn::Cudnn>,
/// A second stream for kernels to optionally execute on.
pub(crate) par_stream: Arc<CudaStream>,
Expand Down
14 changes: 11 additions & 3 deletions src/tensor_ops/reshape_to/cuda_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
};
use cudarc::{
driver::{DeviceSlice, LaunchAsync},
nvrtc::compile_ptx,
nvrtc::{compile_ptx_with_opts, CompileOptions},
types::CudaTypeName,
};

Expand All @@ -17,7 +17,11 @@ impl<E: Dtype + CudaTypeName> super::ReshapeKernel<E> for Cuda {
let module = std::format!("reshape_fwd_{}", E::NAME);
if !self.dev.has_func(&module, "reshape_fwd") {
let src = FWD_KERNEL.replace("$T", E::NAME);
let ptx = compile_ptx(src).unwrap();
let opts = CompileOptions {
arch: Some(env!("CUDA_COMPUTE_CAP")),
..Default::default()
};
let ptx = compile_ptx_with_opts(src, opts).unwrap();
self.dev.load_ptx(ptx, &module, &["reshape_fwd"])?;
}
let fwd_fn = self.dev.get_func(&module, "reshape_fwd").unwrap();
Expand Down
Expand Up @@ -56,7 +60,11 @@ impl<E: Dtype + CudaTypeName> super::ReshapeKernel<E> for Cuda {
let module = std::format!("reshape_bwd_{}", E::NAME);
if !self.dev.has_func(&module, "reshape_bwd") {
let src = BWD_KERNEL.replace("$T", E::NAME);
let ptx = compile_ptx(src).unwrap();
let opts = CompileOptions {
arch: Some(env!("CUDA_COMPUTE_CAP")),
..Default::default()
};
let ptx = compile_ptx_with_opts(src, opts).unwrap();
self.dev.load_ptx(ptx, &module, &["reshape_bwd"])?;
}
let bwd_fn = self.dev.get_func(&module, "reshape_bwd").unwrap();
Expand Down

0 comments on commit ee64526

Please sign in to comment.