diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs index 65a8fbf289..3a8a0bb915 100644 --- a/candle-nn/tests/ops.rs +++ b/candle-nn/tests/ops.rs @@ -77,6 +77,27 @@ fn rms_norm(device: &Device) -> Result<()> { Ok(()) } +fn rms_norml(device: &Device) -> Result<()> { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + let (b_size, seq_len, head_dim) = (24, 70, 64); + let el_count = b_size * seq_len * head_dim; + let mut rng = StdRng::seed_from_u64(299792458); + let src: Vec = (0..el_count).map(|_| rng.gen::()).collect(); + let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?; + let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?; + let t = candle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?; + let t2 = candle_nn::ops::rms_norm_slow(&tensor, &alpha, 1e-5)?; + let diff = (t - t2)? + .abs()? + .flatten_all()? + .max(0)? + .reshape(())? + .to_vec0::()?; + assert!(diff < 1e-5); + Ok(()) +} + fn layer_norm(device: &Device) -> Result<()> { let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]; let tensor = Tensor::new(data, device)?; @@ -103,6 +124,28 @@ fn layer_norm(device: &Device) -> Result<()> { Ok(()) } +fn layer_norml(device: &Device) -> Result<()> { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + let (b_size, seq_len, head_dim) = (24, 70, 64); + let el_count = b_size * seq_len * head_dim; + let mut rng = StdRng::seed_from_u64(299792458); + let src: Vec = (0..el_count).map(|_| rng.gen::()).collect(); + let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?; + let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?; + let beta = Tensor::zeros(head_dim, candle::DType::F32, device)?; + let t = candle_nn::ops::layer_norm(&tensor, &alpha, &beta, 1e-5)?; + let t2 = candle_nn::ops::layer_norm_slow(&tensor, &alpha, &beta, 1e-5)?; + let diff = (t - t2)? + .abs()? + .flatten_all()? + .max(0)? + .reshape(())? + .to_vec0::()?; + assert!(diff < 1e-5); + Ok(()) +} + #[test] fn softmax_numerical_stability() -> Result<()> { let dev = &Device::Cpu; @@ -211,5 +254,7 @@ test_device!(rope, rope_cpu, rope_gpu, rope_metal); test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal); test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal); test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal); +test_device!(rms_norml, rms_norml_cpu, rms_norml_gpu, rms_norml_metal); test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal); +test_device!(layer_norml, lnl_cpu, lnl_gpu, lnl_metal); test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);