Skip to content

Commit

Permalink
Fixing cuda kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Oct 25, 2023
1 parent ca9548b commit 9a2564a
Show file tree
Hide file tree
Showing 31 changed files with 52 additions and 81 deletions.
8 changes: 4 additions & 4 deletions dfdx-core/src/tensor/cpu/allocate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ impl Cpu {
numel: usize,
elem: E,
) -> Result<CachableVec<E>, Error> {
let data = self.cache.try_pop::<E>(numel).map_or_else(
let data: Result<Vec<E>, Error> = self.cache.try_pop::<E>(numel).map_or_else(
#[cfg(feature = "fast-alloc")]
|| Ok(std::vec![elem; numel]),
#[cfg(not(feature = "fast-alloc"))]
|| {
let mut data: Vec<E> = Vec::new();
data.try_reserve(numel).map_err(|_| CpuError::OutOfMemory)?;
data.try_reserve(numel).map_err(|_| Error::OutOfMemory)?;
data.resize(numel, elem);
Ok(data)
},
Expand All @@ -45,10 +45,10 @@ impl Cpu {
data.fill(elem);
Ok(data)
},
)?;
);

Ok(CachableVec {
data,
data: data?,
cache: self.cache.clone(),
})
}
Expand Down
8 changes: 4 additions & 4 deletions dfdx-core/src/tensor/cuda/allocate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

use crate::{
shapes::*,
tensor::{masks::triangle_mask, storage_traits::*, unique_id, Cpu, CpuError, NoneTape, Tensor},
tensor::{masks::triangle_mask, storage_traits::*, unique_id, Cpu, Error, NoneTape, Tensor},
};

use super::{device::CachableCudaSlice, Cuda, CudaError};
use super::{device::CachableCudaSlice, Cuda};

use cudarc::driver::{CudaSlice, DeviceSlice};
use rand::Rng;
Expand All @@ -16,7 +16,7 @@ impl Cuda {
&self,
shape: S,
buf: Vec<E>,
) -> Result<Tensor<S, E, Self>, CudaError> {
) -> Result<Tensor<S, E, Self>, Error> {
let mut slice = unsafe { self.alloc_empty(buf.len()) }?;
self.dev.htod_copy_into(buf, &mut slice)?;
Ok(self.build_tensor(shape, shape.strides(), slice))
Expand Down Expand Up @@ -184,7 +184,7 @@ impl<E: Unit> TensorFromVec<E> for Cuda {
let num_elements = shape.num_elements();

if src.len() != num_elements {
Err(CudaError::Cpu(CpuError::WrongNumElements))
Err(Error::WrongNumElements)
} else {
self.tensor_from_host_buf(shape, src)
}
Expand Down
53 changes: 14 additions & 39 deletions dfdx-core/src/tensor/cuda/device.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::shapes::{Shape, Unit};
use crate::tensor::cpu::{Cpu, CpuError};
use crate::tensor::cpu::Cpu;
use crate::tensor::{
cache::TensorCache, Cache, HasErr, NoneTape, RandomU64, Storage, Synchronize, Tensor,
cache::TensorCache, Cache, Error, NoneTape, RandomU64, Storage, Synchronize, Tensor,
};

use cudarc::driver::{DevicePtr, DevicePtrMut, DeviceRepr};
Expand Down Expand Up @@ -32,37 +32,22 @@ pub struct Cuda {
pub(crate) cache: Arc<TensorCache<CUdeviceptr>>,
}

#[derive(Debug)]
pub enum CudaError {
Blas(CublasError),
#[cfg(feature = "cudnn")]
Cudnn(cudarc::cudnn::CudnnError),
Driver(DriverError),
Cpu(CpuError),
}

impl From<CpuError> for CudaError {
fn from(value: CpuError) -> Self {
Self::Cpu(value)
}
}

impl From<CublasError> for CudaError {
impl From<CublasError> for Error {
fn from(value: CublasError) -> Self {
Self::Blas(value)
Self::CublasError(value)
}
}

impl From<DriverError> for CudaError {
impl From<DriverError> for Error {
fn from(value: DriverError) -> Self {
Self::Driver(value)
Self::CudaDriverError(value)
}
}

#[cfg(feature = "cudnn")]
impl From<cudarc::cudnn::CudnnError> for CudaError {
impl From<cudarc::cudnn::CudnnError> for Error {
fn from(value: cudarc::cudnn::CudnnError) -> Self {
Self::Cudnn(value)
Self::CudnnError(value)
}
}

Expand All @@ -79,12 +64,12 @@ impl Cuda {
}

/// Constructs the device on ordinal 0 with the given seed.
pub fn try_seed_from_u64(seed: u64) -> Result<Self, CudaError> {
pub fn try_seed_from_u64(seed: u64) -> Result<Self, Error> {
Self::try_build(0, seed)
}

/// Constructs with the given seed & device ordinal
pub fn try_build(ordinal: usize, seed: u64) -> Result<Self, CudaError> {
pub fn try_build(ordinal: usize, seed: u64) -> Result<Self, Error> {
let cpu = Cpu::seed_from_u64(seed);
let dev = CudaDevice::new(ordinal)?;
let blas = Arc::new(CudaBlas::new(dev.clone())?);
Expand Down Expand Up @@ -112,7 +97,7 @@ impl Cuda {
pub(crate) unsafe fn alloc_empty<E: DeviceRepr>(
&self,
len: usize,
) -> Result<CudaSlice<E>, CudaError> {
) -> Result<CudaSlice<E>, Error> {
let data = self.cache.try_pop::<E>(len).map_or_else(
|| self.dev.alloc::<E>(len),
|ptr| Ok(self.dev.upgrade_device_ptr(ptr, len)),
Expand All @@ -123,7 +108,7 @@ impl Cuda {
pub(crate) unsafe fn get_workspace<E>(
&self,
len: usize,
) -> Result<MutexGuard<CudaSlice<u8>>, CudaError> {
) -> Result<MutexGuard<CudaSlice<u8>>, Error> {
let num_bytes_required = len * std::mem::size_of::<E>();
let mut workspace = self.workspace.as_ref().lock().unwrap();

Expand All @@ -137,16 +122,6 @@ impl Cuda {
}
}

impl std::fmt::Display for CudaError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "{self:?}")
}
}

impl HasErr for Cuda {
type Err = CudaError;
}

/// A [CudaSlice] that can be cloned without allocating new memory.
/// When [Drop]ed it will insert its data into the cache.
#[derive(Debug)]
Expand Down Expand Up @@ -273,8 +248,8 @@ impl Cache for Cuda {
}

impl Synchronize for Cuda {
fn try_synchronize(&self) -> Result<(), CudaError> {
self.dev.synchronize().map_err(CudaError::from)
fn try_synchronize(&self) -> Result<(), Error> {
self.dev.synchronize().map_err(Error::from)
}
}

Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor/cuda/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mod allocate;
mod device;

pub use device::{Cuda, CudaError};
pub use device::Cuda;

pub(crate) fn launch_cfg<const NUM_THREADS: u32>(n: u32) -> cudarc::driver::LaunchConfig {
let num_blocks = (n + NUM_THREADS - 1) / NUM_THREADS;
Expand Down
8 changes: 2 additions & 6 deletions dfdx-core/src/tensor/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub enum Error {
UnusedTensors(Vec<crate::tensor::UniqueId>),

#[cfg(feature = "cuda")]
CublasError(cudarc::cublas::CublasError),
CublasError(cudarc::cublas::result::CublasError),
#[cfg(feature = "cuda")]
CudaDriverError(cudarc::driver::DriverError),

Expand All @@ -19,11 +19,7 @@ pub enum Error {

impl std::fmt::Display for Error {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::OutOfMemory => f.write_str("Error::OutOfMemory"),
Self::WrongNumElements => f.write_str("Error::WrongNumElements"),
_ => todo!(),
}
write!(f, "{self:?}")
}
}

Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ pub type AutoDevice = Cpu;
#[cfg(feature = "cuda")]
pub(crate) use cuda::launch_cfg;
#[cfg(feature = "cuda")]
pub use cuda::{Cuda, CudaError};
pub use cuda::Cuda;
#[cfg(feature = "cuda")]
pub type AutoDevice = Cuda;

Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/adam/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
dtypes::*,
tensor::{launch_cfg, Cuda},
tensor::{launch_cfg, Cuda, Error},
tensor_ops::optim::*,
};

Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/axpy/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
dtypes::*,
tensor::{launch_cfg, Cuda},
tensor::{launch_cfg, Cuda, Error},
};

use cudarc::driver::{DeviceSlice, LaunchAsync};
Expand Down
4 changes: 2 additions & 2 deletions dfdx-core/src/tensor_ops/boolean/cuda_kernels.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::BooleanKernel;
use crate::{
shapes::Shape,
tensor::{launch_cfg, Cuda, CudaError, Tensor},
tensor::{launch_cfg, Cuda, Error, Tensor},
};
use cudarc::driver::*;

Expand All @@ -15,7 +15,7 @@ impl Cuda {
fn_name: &str,
lhs: &Tensor<S, bool, Self>,
rhs: &Tensor<S, bool, Self>,
) -> Result<Tensor<S, bool, Self>, CudaError> {
) -> Result<Tensor<S, bool, Self>, Error> {
if !self.dev.has_func(MODULE_NAME, fn_name) {
self.dev
.load_ptx(PTX_SRC.into(), MODULE_NAME, &ALL_FN_NAMES)?;
Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/choose/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
dtypes::*,
shapes::*,
tensor::{launch_cfg, Cuda, Storage, Tensor},
tensor::{launch_cfg, Cuda, Error, Storage, Tensor},
};
use cudarc::driver::{CudaSlice, LaunchAsync};

Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/cmp/cuda_kernels.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
dtypes::*,
shapes::Shape,
tensor::{launch_cfg, Cuda, Tensor},
tensor::{launch_cfg, Cuda, Error, Tensor},
};
use cudarc::driver::{CudaSlice, LaunchAsync};

Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/concat/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
shapes::*,
tensor::{launch_cfg, Cuda, Tensor},
tensor::{launch_cfg, Cuda, Error, Tensor},
};
use cudarc::{
driver::{DeviceSlice, LaunchAsync},
Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/concat_along/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
shapes::*,
tensor::{launch_cfg, Cuda, GhostTensor, Tensor},
tensor::{launch_cfg, Cuda, Error, GhostTensor, Tensor},
};
use cudarc::{
driver::{DeviceSlice, LaunchAsync},
Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/conv1d/cuda_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use cudarc::driver::{DeviceRepr, LaunchAsync, ValidAsZeroBits};
use crate::{
dtypes::*,
shapes::*,
tensor::{launch_cfg, Cuda, Tensor, Tensorlike},
tensor::{launch_cfg, Cuda, Error, Tensor, Tensorlike},
};

use std::sync::Arc;
Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/conv2d/cuda_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use cudarc::driver::{DeviceRepr, LaunchAsync, ValidAsZeroBits};
use crate::{
dtypes::*,
shapes::*,
tensor::{launch_cfg, Cuda, Tensor, Tensorlike},
tensor::{launch_cfg, Cuda, Error, Tensor, Tensorlike},
};

use std::sync::Arc;
Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/conv2d/cudnn_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use cudarc::driver::DeviceSlice;
use crate::{
dtypes::*,
shapes::*,
tensor::{Cuda, Tensor, Tensorlike},
tensor::{Cuda, Error, Tensor, Tensorlike},
};

use std::sync::Arc;
Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/convtrans2d/cuda_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use cudarc::driver::{DeviceRepr, LaunchAsync, ValidAsZeroBits};
use crate::{
dtypes::*,
shapes::*,
tensor::{launch_cfg, Cuda, Tensor, Tensorlike},
tensor::{launch_cfg, Cuda, Error, Tensor, Tensorlike},
};

use std::sync::Arc;
Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/dropout/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
dtypes::*,
shapes::*,
tensor::{launch_cfg, Cuda, Tensor},
tensor::{launch_cfg, Cuda, Error, Tensor},
};

use std::vec::Vec;
Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/matmul/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
dtypes::*,
shapes::*,
tensor::{cuda::Cuda, Tensor},
tensor::{cuda::Cuda, Error, Tensor},
};

use cudarc::{
Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/max_to/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
dtypes::*,
shapes::*,
tensor::{launch_cfg, Cuda, Tensor},
tensor::{launch_cfg, Cuda, Error, Tensor},
tensor_ops::reduction_utils::*,
};

Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/min_to/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
dtypes::*,
shapes::*,
tensor::{launch_cfg, Cuda, Tensor},
tensor::{launch_cfg, Cuda, Error, Tensor},
tensor_ops::reduction_utils::*,
};

Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/pool2d/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
dtypes::*,
shapes::*,
tensor::{launch_cfg, Cuda, Tensor},
tensor::{launch_cfg, Cuda, Error, Tensor},
};

use std::sync::Arc;
Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/reshape_to/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
shapes::*,
tensor::{launch_cfg, Cuda, Tensor},
tensor::{launch_cfg, Cuda, Error, Tensor},
};
use cudarc::{
driver::{DeviceSlice, LaunchAsync},
Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/rmsprop/cuda_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::RMSpropConfig;
use crate::{
dtypes::*,
tensor::{launch_cfg, Cuda},
tensor::{launch_cfg, Cuda, Error},
tensor_ops::optim::*,
};

Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/select_and_gather/cuda_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use crate::{
dtypes::*,
shapes::{RemoveDimTo, ReplaceDimTo, Shape},
tensor::{launch_cfg, Cuda, Storage, Tensor},
tensor::{launch_cfg, Cuda, Error, Storage, Tensor},
};
use cudarc::driver::{DeviceSlice, LaunchAsync};

Expand Down
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor_ops/sgd/cuda_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::SgdConfig;

use crate::{
dtypes::*,
tensor::{launch_cfg, Cuda},
tensor::{launch_cfg, Cuda, Error},
tensor_ops::optim::*,
};

Expand Down
Loading

0 comments on commit 9a2564a

Please sign in to comment.