Skip to content

Commit

Permalink
Add more testing for the fused layer/rms norm kernels.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Nov 1, 2024
1 parent 463ddac commit ee8beb7
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions candle-nn/tests/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,27 @@ fn rms_norm(device: &Device) -> Result<()> {
Ok(())
}

fn rms_norml(device: &Device) -> Result<()> {
use rand::{rngs::StdRng, Rng, SeedableRng};

let (b_size, seq_len, head_dim) = (24, 70, 64);
let el_count = b_size * seq_len * head_dim;
let mut rng = StdRng::seed_from_u64(299792458);
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?;
let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?;
let t = candle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?;
let t2 = candle_nn::ops::rms_norm_slow(&tensor, &alpha, 1e-5)?;
let diff = (t - t2)?
.abs()?
.flatten_all()?
.max(0)?
.reshape(())?
.to_vec0::<f32>()?;
assert!(diff < 1e-5);
Ok(())
}

fn layer_norm(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
Expand All @@ -103,6 +124,28 @@ fn layer_norm(device: &Device) -> Result<()> {
Ok(())
}

fn layer_norml(device: &Device) -> Result<()> {
use rand::{rngs::StdRng, Rng, SeedableRng};

let (b_size, seq_len, head_dim) = (24, 70, 64);
let el_count = b_size * seq_len * head_dim;
let mut rng = StdRng::seed_from_u64(299792458);
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?;
let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?;
let beta = Tensor::zeros(head_dim, candle::DType::F32, device)?;
let t = candle_nn::ops::layer_norm(&tensor, &alpha, &beta, 1e-5)?;
let t2 = candle_nn::ops::layer_norm_slow(&tensor, &alpha, &beta, 1e-5)?;
let diff = (t - t2)?
.abs()?
.flatten_all()?
.max(0)?
.reshape(())?
.to_vec0::<f32>()?;
assert!(diff < 1e-5);
Ok(())
}

#[test]
fn softmax_numerical_stability() -> Result<()> {
let dev = &Device::Cpu;
Expand Down Expand Up @@ -211,5 +254,7 @@ test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
test_device!(rms_norml, rms_norml_cpu, rms_norml_gpu, rms_norml_metal);
test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
test_device!(layer_norml, lnl_cpu, lnl_gpu, lnl_metal);
test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);

0 comments on commit ee8beb7

Please sign in to comment.