Skip to content

Commit

Permalink
experimental add metal impl
Browse files Browse the repository at this point in the history
  • Loading branch information
MilkFather committed Apr 28, 2024
1 parent 75de6f1 commit 37cad9d
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 1 deletion.
2 changes: 1 addition & 1 deletion candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ macro_rules! ops{
pub mod unary {
ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
tanh, recip, silu, sign
tanh, recip, silu, sign, sigmoid
);
}
pub mod binary {
Expand Down
5 changes: 5 additions & 0 deletions candle-metal-kernels/src/unary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ template <typename T> METAL_FUNC T relu(T in){
template <typename T> METAL_FUNC T silu(T in){
return in / (static_cast<T>(1) + exp(-in));
}
template <typename T> METAL_FUNC T sigmoid(T in) {
return recip(static_cast<T>(1) + exp(-in));
}

#define TILE_SIZE 2

Expand Down Expand Up @@ -155,6 +158,7 @@ UNARY_OP(tanh)
UNARY_OP(recip)
UNARY_OP(relu)
UNARY_OP(sign)
UNARY_OP(sigmoid)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
Expand Down Expand Up @@ -185,6 +189,7 @@ BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)
BFLOAT_UNARY_OP(sign)
BFLOAT_UNARY_OP(sigmoid)

UNARY(id, bfloat, copy_bf16, copy_bf16_strided)

Expand Down
94 changes: 94 additions & 0 deletions candle-nn/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,100 @@ impl candle::CustomOp1 for Sigmoid {
Ok((storage, layout.shape().clone()))
}

#[cfg(feature = "metal")]
fn metal_fwd(
&self,
storage: &candle::MetalStorage,
layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
use candle::MetalError;
let device = storage.device();
let dtype = storage.dtype();
let shape = layout.shape();
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
let command_buffer = device.command_buffer()?;
command_buffer.set_label("sigmoid");
let src = candle_metal_kernels::BufferOffset {
buffer: storage.buffer(),
offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
};

match (el_count % 2, dtype, layout.is_contiguous()) {
(0, DType::BF16 | DType::F16, true) => {
use candle_metal_kernels::unary::contiguous_tiled;
let kernel_name = match dtype {
DType::F16 => contiguous_tiled::sigmoid::HALF,
DType::F32 => contiguous_tiled::sigmoid::FLOAT,
DType::BF16 => contiguous_tiled::sigmoid::BFLOAT,
dtype => {
candle::bail!(
"Metal contiguous_tiled unary sigmoid {dtype:?} not implemented"
)
}
};
candle_metal_kernels::call_unary_contiguous_tiled(
device.metal_device(),
&command_buffer,
device.kernels(),
kernel_name,
el_count,
src,
&buffer,
)
.map_err(MetalError::from)?;
}
(_, _, true) => {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match dtype {
DType::F16 => contiguous::sigmoid::HALF,
DType::F32 => contiguous::sigmoid::FLOAT,
DType::BF16 => contiguous::sigmoid::BFLOAT,
dtype => {
candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
}
};
candle_metal_kernels::call_unary_contiguous(
device.metal_device(),
&command_buffer,
device.kernels(),
kernel_name,
el_count,
src,
&buffer,
)
.map_err(MetalError::from)?;
}
(_, _, false) => {
use candle_metal_kernels::unary::strided;
let kernel_name = match dtype {
DType::F16 => strided::sigmoid::HALF,
DType::F32 => strided::sigmoid::FLOAT,
DType::BF16 => strided::sigmoid::BFLOAT,
dtype => {
candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
}
};
let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
candle_metal_kernels::call_unary_strided(
device.metal_device(),
&command_buffer,
device.kernels(),
kernel_name,
layout.dims(),
src,
layout.stride(),
dst,
)
.map_err(MetalError::from)?;
}
}

let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
Ok((new_storage, layout.shape().clone()))
}

fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
// d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x)
let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
Expand Down

0 comments on commit 37cad9d

Please sign in to comment.