Cuda support for the mnist training. (#277)
* Cuda support for the mnist training.

* min/max fix + testing.

* Add the argmin/argmax tests.

* More cuda support for argmin/argmax.

* Cuda kernels for argmin and argmax.
LaurentMazare authored Jul 29, 2023
1 parent 16c3338 commit c950a5c
Showing 6 changed files with 453 additions and 28 deletions.
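For orientation, the ops this commit wires up on CUDA are the reduce variants exercised by the new tests below. A minimal usage sketch, assuming a CUDA-enabled build and the crate imported as `candle_core` (fall back to `Device::Cpu` otherwise):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Assumes a CUDA device is available; Device::Cpu exercises the same API.
    let device = Device::new_cuda(0)?;
    let t = Tensor::new(&[[3u32, 1, 4], [1, 5, 9]], &device)?;
    // min/max reduce to the input dtype; argmin/argmax reduce to u32 indices.
    assert_eq!(t.max_keepdim(1)?.to_vec2::<u32>()?, &[[4], [9]]);
    assert_eq!(t.argmax_keepdim(1)?.to_vec2::<u32>()?, &[[2], [2]]);
    Ok(())
}
```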
2 changes: 1 addition & 1 deletion candle-core/src/cpu_backend.rs
@@ -244,7 +244,7 @@ impl ReduceIndex {
                         val = s
                     }
                 }
-                dst[unstr_index] = g(val, acc)
+                dst_to_set[unstr_index] = g(val, acc)
             }
         }
     }
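A hedged reading of the one-line fix above: the write goes through `dst_to_set`, a sub-slice presumably already positioned at the current output block's offset, so the unstrided index lands at the right absolute position in the destination buffer. A toy illustration of the pattern (all names hypothetical, not candle's code):

```rust
fn main() {
    let mut dst = vec![0u32; 6];
    let reduced = [[7u32, 8, 9], [4, 5, 6]];
    for (block, vals) in reduced.iter().enumerate() {
        // Index through a slice scoped to this block's region of the output...
        let dst_to_set = &mut dst[block * 3..(block + 1) * 3];
        for (unstr_index, v) in vals.iter().enumerate() {
            // ...writing `dst[unstr_index]` here instead would overwrite
            // block 0 on every pass of the outer loop.
            dst_to_set[unstr_index] = *v;
        }
    }
    assert_eq!(dst, [7, 8, 9, 4, 5, 6]);
}
```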
65 changes: 50 additions & 15 deletions candle-core/src/cuda_backend.rs
@@ -438,6 +438,28 @@ trait Map2InPlace {
     }
 }
 
+trait Map1Any {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+        wrap: W,
+    ) -> Result<S>;
+
+    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
+        let out = match s {
+            S::U8(s) => self.f(s, d, l, S::U8)?,
+            S::U32(s) => self.f(s, d, l, S::U32)?,
+            S::BF16(s) => self.f(s, d, l, S::BF16)?,
+            S::F16(s) => self.f(s, d, l, S::F16)?,
+            S::F32(s) => self.f(s, d, l, S::F32)?,
+            S::F64(s) => self.f(s, d, l, S::F64)?,
+        };
+        Ok(out)
+    }
+}
+
 trait Map2Any {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
@@ -574,13 +596,14 @@ impl<'a> Map1 for Sum<'a> {
 }
 
 struct FastReduce<'a>(&'a [usize], ReduceOp);
-impl<'a> Map1 for FastReduce<'a> {
-    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+impl<'a> Map1Any for FastReduce<'a> {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
         &self,
         src: &CudaSlice<T>,
         dev: &CudaDevice,
         layout: &Layout,
-    ) -> Result<CudaSlice<T>> {
+        wrap: W,
+    ) -> Result<S> {
         let src_stride = layout.stride();
         let src_dims = layout.shape().dims();
         let src_el: usize = src_dims.iter().product();
@@ -615,20 +638,32 @@ impl<'a> Map1 for FastReduce<'a> {
             .htod_copy([dims.as_slice(), stride.as_slice()].concat())
             .w()?;
         let src = &src.slice(layout.start_offset()..);
-        let name = match self.1 {
-            ReduceOp::Sum => "fast_sum",
-            ReduceOp::Min => "fast_min",
-            ReduceOp::Max => "fast_max",
-            ReduceOp::ArgMin => "fast_argmin",
-            ReduceOp::ArgMax => "fast_argmax",
+        let (name, check_empty, return_index) = match self.1 {
+            ReduceOp::Sum => ("fast_sum", false, false),
+            ReduceOp::Min => ("fast_min", true, false),
+            ReduceOp::Max => ("fast_max", true, false),
+            ReduceOp::ArgMin => ("fast_argmin", true, true),
+            ReduceOp::ArgMax => ("fast_argmax", true, true),
         };
+        if check_empty && layout.shape().elem_count() == 0 {
+            Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
+        }
         let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
-        // SAFETY: filled in by the follow up kernel.
-        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
-        let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
-        // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
-        Ok(out)
+        if return_index {
+            // SAFETY: filled in by the follow up kernel.
+            let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
+            let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
+            // SAFETY: ffi.
+            unsafe { func.launch(cfg, params) }.w()?;
+            Ok(S::U32(out))
+        } else {
+            // SAFETY: filled in by the follow up kernel.
+            let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+            let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
+            // SAFETY: ffi.
+            unsafe { func.launch(cfg, params) }.w()?;
+            Ok(wrap(out))
+        }
     }
 }
 
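The point of Map1Any over Map1 is that the output dtype is no longer tied to the input's: argmin/argmax always produce u32 indices, while min/max/sum keep the input dtype, re-wrapped into the storage enum via the `wrap` constructor. A self-contained sketch of that shape of API, with illustrative names rather than candle's actual types:

```rust
// Illustrative stand-ins for candle's storage enum and the Map1Any `wrap` idea.
enum S {
    U32(Vec<u32>),
    F32(Vec<f32>),
}

// Reduce `src` to one element: either its max (same dtype, re-wrapped by
// `wrap`) or its argmax (always u32), mirroring the `return_index` branch.
fn fast_reduce<T: Copy + PartialOrd, W: Fn(Vec<T>) -> S>(
    src: &[T],
    return_index: bool,
    wrap: W,
) -> S {
    // Mirrors the new `check_empty` guard: min/max/argmin/argmax of an
    // empty tensor is an error.
    assert!(!src.is_empty(), "reduce over an empty tensor");
    let mut best = 0usize;
    for (i, v) in src.iter().enumerate() {
        if *v > src[best] {
            best = i;
        }
    }
    if return_index {
        S::U32(vec![best as u32])
    } else {
        wrap(vec![src[best]])
    }
}

fn main() {
    let xs = [3f32, 1., 4., 1., 5.];
    assert!(matches!(fast_reduce(&xs, false, S::F32), S::F32(v) if v == [5.0]));
    assert!(matches!(fast_reduce(&xs, true, S::F32), S::U32(v) if v == [4]));
}
```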
276 changes: 276 additions & 0 deletions candle-core/tests/tensor_tests.rs
@@ -164,6 +164,278 @@ fn sum(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn min(device: &Device) -> Result<()> {
+    let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
+    let tensor = Tensor::new(data, device)?;
+    assert_eq!(
+        tensor.min_keepdim(2)?.to_vec3::<u32>()?,
+        &[[[1], [1]], [[1], [2]]]
+    );
+    assert_eq!(
+        tensor.min_keepdim(0)?.to_vec3::<u32>()?,
+        &[[[2, 1, 4], [1, 2, 8]]],
+    );
+    let data: Vec<u32> = (200..4000u32).collect();
+    let tensor = Tensor::new(data.as_slice(), device)?;
+    assert_eq!(tensor.min_keepdim(0)?.to_vec1::<u32>()?, &[200]);
+    let tensor = tensor.reshape((1900, 2))?;
+    assert_eq!(
+        tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::<u32>()?,
+        &[[200]]
+    );
+    assert_eq!(
+        tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::<u32>()?,
+        &[[200]]
+    );
+    assert_eq!(tensor.min_keepdim(0)?.to_vec2::<u32>()?, &[[200, 201]]);
+
+    // Make the tensor non contiguous.
+    let tensor = tensor.t()?.contiguous()?.t()?;
+    assert_eq!(
+        tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::<u32>()?,
+        &[[200]]
+    );
+    assert_eq!(
+        tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::<u32>()?,
+        &[[200]]
+    );
+    assert_eq!(tensor.min_keepdim(0)?.to_vec2::<u32>()?, &[[200, 201]]);
+
+    let t1 = tensor.reshape((190, 5, 4))?;
+    let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
+    for tensor in [t1, t2] {
+        assert_eq!(
+            tensor
+                .min_keepdim(0)?
+                .min_keepdim(2)?
+                .min_keepdim(1)?
+                .to_vec3::<u32>()?,
+            &[[[200]]]
+        );
+        assert_eq!(
+            tensor.min_keepdim(0)?.to_vec3::<u32>()?,
+            &[[
+                [200, 201, 202, 203],
+                [204, 205, 206, 207],
+                [208, 209, 210, 211],
+                [212, 213, 214, 215],
+                [216, 217, 218, 219]
+            ]]
+        );
+    }
+    Ok(())
+}
+
+fn max(device: &Device) -> Result<()> {
+    let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
+    let tensor = Tensor::new(data, device)?;
+    assert_eq!(
+        tensor.max_keepdim(2)?.to_vec3::<u32>()?,
+        &[[[4], [9]], [[7], [8]]]
+    );
+    assert_eq!(
+        tensor.max_keepdim(0)?.to_vec3::<u32>()?,
+        &[[[3, 1, 7], [8, 5, 9]]],
+    );
+    let data: Vec<u32> = (200..4000u32).collect();
+    let tensor = Tensor::new(data.as_slice(), device)?;
+    assert_eq!(tensor.max_keepdim(0)?.to_vec1::<u32>()?, &[3999]);
+    let tensor = tensor.reshape((1900, 2))?;
+    assert_eq!(
+        tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::<u32>()?,
+        &[[3999]]
+    );
+    assert_eq!(
+        tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::<u32>()?,
+        &[[3999]]
+    );
+    assert_eq!(tensor.max_keepdim(0)?.to_vec2::<u32>()?, &[[3998, 3999]]);
+
+    // Make the tensor non contiguous.
+    let tensor = tensor.t()?.contiguous()?.t()?;
+    assert_eq!(
+        tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::<u32>()?,
+        &[[3999]]
+    );
+    assert_eq!(
+        tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::<u32>()?,
+        &[[3999]]
+    );
+    assert_eq!(tensor.max_keepdim(0)?.to_vec2::<u32>()?, &[[3998, 3999]]);
+
+    let t1 = tensor.reshape((190, 5, 4))?;
+    let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
+    for tensor in [t1, t2] {
+        assert_eq!(
+            tensor
+                .max_keepdim(0)?
+                .max_keepdim(2)?
+                .max_keepdim(1)?
+                .to_vec3::<u32>()?,
+            &[[[3999]]]
+        );
+        assert_eq!(
+            tensor.max_keepdim(0)?.to_vec3::<u32>()?,
+            &[[
+                [3980, 3981, 3982, 3983],
+                [3984, 3985, 3986, 3987],
+                [3988, 3989, 3990, 3991],
+                [3992, 3993, 3994, 3995],
+                [3996, 3997, 3998, 3999]
+            ]]
+        );
+    }
+    Ok(())
+}
+
+fn argmin(device: &Device) -> Result<()> {
+    let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
+    let tensor = Tensor::new(data, device)?;
+    assert_eq!(
+        tensor.argmin_keepdim(2)?.to_vec3::<u32>()?,
+        &[[[1], [0]], [[1], [1]]]
+    );
+    assert_eq!(
+        tensor.argmin_keepdim(0)?.to_vec3::<u32>()?,
+        &[[[1, 0, 0], [0, 1, 1]]],
+    );
+    let data: Vec<u32> = (200..4000u32).collect();
+    let tensor = Tensor::new(data.as_slice(), device)?;
+    assert_eq!(tensor.argmin_keepdim(0)?.to_vec1::<u32>()?, &[0]);
+    let tensor = tensor.reshape((1900, 2))?;
+    assert_eq!(
+        tensor
+            .argmin_keepdim(0)?
+            .argmin_keepdim(1)?
+            .to_vec2::<u32>()?,
+        &[[0]]
+    );
+    assert_eq!(
+        tensor
+            .argmin_keepdim(1)?
+            .argmin_keepdim(0)?
+            .to_vec2::<u32>()?,
+        &[[0]]
+    );
+    assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::<u32>()?, &[[0, 0]]);
+
+    // Make the tensor non contiguous.
+    let tensor = tensor.t()?.contiguous()?.t()?;
+    assert_eq!(
+        tensor
+            .argmin_keepdim(0)?
+            .argmin_keepdim(1)?
+            .to_vec2::<u32>()?,
+        &[[0]]
+    );
+    assert_eq!(
+        tensor
+            .argmin_keepdim(1)?
+            .argmin_keepdim(0)?
+            .to_vec2::<u32>()?,
+        &[[0]]
+    );
+    assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::<u32>()?, &[[0, 0]]);
+
+    let t1 = tensor.reshape((190, 5, 4))?;
+    let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
+    for tensor in [t1, t2] {
+        assert_eq!(
+            tensor
+                .argmin_keepdim(0)?
+                .argmin_keepdim(2)?
+                .argmin_keepdim(1)?
+                .to_vec3::<u32>()?,
+            &[[[0]]]
+        );
+        assert_eq!(
+            tensor.argmin_keepdim(0)?.to_vec3::<u32>()?,
+            &[[
+                [0, 0, 0, 0],
+                [0, 0, 0, 0],
+                [0, 0, 0, 0],
+                [0, 0, 0, 0],
+                [0, 0, 0, 0],
+            ]]
+        );
+    }
+    Ok(())
+}
+
+fn argmax(device: &Device) -> Result<()> {
+    let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
+    let tensor = Tensor::new(data, device)?;
+    assert_eq!(
+        tensor.argmax_keepdim(2)?.to_vec3::<u32>()?,
+        &[[[2], [2]], [[2], [0]]]
+    );
+    assert_eq!(
+        tensor.argmax_keepdim(0)?.to_vec3::<u32>()?,
+        &[[[0, 0, 1], [1, 0, 0]]],
+    );
+    let data: Vec<u32> = (200..4000u32).collect();
+    let tensor = Tensor::new(data.as_slice(), device)?;
+    assert_eq!(tensor.argmax_keepdim(0)?.to_vec1::<u32>()?, &[3799]);
+    let tensor = tensor.reshape((1900, 2))?;
+    assert_eq!(
+        tensor
+            .argmax_keepdim(0)?
+            .argmax_keepdim(1)?
+            .to_vec2::<u32>()?,
+        &[[0]]
+    );
+    assert_eq!(
+        tensor
+            .argmax_keepdim(1)?
+            .argmax_keepdim(0)?
+            .to_vec2::<u32>()?,
+        &[[0]]
+    );
+    assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::<u32>()?, &[[1899, 1899]]);
+
+    // Make the tensor non contiguous.
+    let tensor = tensor.t()?.contiguous()?.t()?;
+    assert_eq!(
+        tensor
+            .argmax_keepdim(0)?
+            .argmax_keepdim(1)?
+            .to_vec2::<u32>()?,
+        &[[0]]
+    );
+    assert_eq!(
+        tensor
+            .argmax_keepdim(1)?
+            .argmax_keepdim(0)?
+            .to_vec2::<u32>()?,
+        &[[0]]
+    );
+    assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::<u32>()?, &[[1899, 1899]]);
+
+    let t1 = tensor.reshape((190, 5, 4))?;
+    let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
+    for tensor in [t1, t2] {
+        assert_eq!(
+            tensor
+                .argmax_keepdim(0)?
+                .argmax_keepdim(2)?
+                .argmax_keepdim(1)?
+                .to_vec3::<u32>()?,
+            &[[[0]]]
+        );
+        assert_eq!(
+            tensor.argmax_keepdim(0)?.to_vec3::<u32>()?,
+            &[[
+                [189, 189, 189, 189],
+                [189, 189, 189, 189],
+                [189, 189, 189, 189],
+                [189, 189, 189, 189],
+                [189, 189, 189, 189],
+            ]]
+        );
+    }
+    Ok(())
+}
+
 fn narrow(device: &Device) -> Result<()> {
     let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
     let tensor = Tensor::new(data, device)?;
@@ -581,6 +853,10 @@ test_device!(narrow, narrow_cpu, narrow_gpu);
 test_device!(broadcast, broadcast_cpu, broadcast_gpu);
 test_device!(cat, cat_cpu, cat_gpu);
 test_device!(sum, sum_cpu, sum_gpu);
+test_device!(min, min_cpu, min_gpu);
+test_device!(max, max_cpu, max_gpu);
+test_device!(argmax, argmax_cpu, argmax_gpu);
+test_device!(argmin, argmin_cpu, argmin_gpu);
 test_device!(transpose, transpose_cpu, transpose_gpu);
 test_device!(binary_op, binary_op_cpu, binary_op_gpu);
 test_device!(embeddings, embeddings_cpu, embeddings_gpu);
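The test_device! registrations above pair each test body with a CPU and a GPU entry point. A rough sketch of what such a macro presumably expands to, as an illustration of the pattern rather than candle's actual macro:

```rust
macro_rules! test_device {
    ($fn_name:ident, $cpu:ident, $gpu:ident) => {
        // Always run the test body against the CPU backend.
        #[test]
        fn $cpu() -> Result<()> {
            $fn_name(&Device::Cpu)
        }

        // Run the same body on CUDA when the feature is enabled.
        #[cfg(feature = "cuda")]
        #[test]
        fn $gpu() -> Result<()> {
            $fn_name(&Device::new_cuda(0)?)
        }
    };
}
```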
