diff --git a/Cargo.toml b/Cargo.toml index b3dafd0f..f384438e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ quickcheck = { version = "0.8.1", default-features = false } ndarray-rand = "0.9" approx = "0.3" quickcheck_macros = "0.8" +num-bigint = "0.2.2" [[bench]] name = "sort" diff --git a/src/deviation.rs b/src/deviation.rs new file mode 100644 index 00000000..c9f7c77c --- /dev/null +++ b/src/deviation.rs @@ -0,0 +1,376 @@ +use ndarray::{ArrayBase, Data, Dimension, Zip}; +use num_traits::{Signed, ToPrimitive}; +use std::convert::Into; +use std::ops::AddAssign; + +use crate::errors::{MultiInputError, ShapeMismatch}; + +/// An extension trait for `ArrayBase` providing functions +/// to compute different deviation measures. +pub trait DeviationExt +where + S: Data, + D: Dimension, +{ + /// Counts the number of indices at which the elements of the arrays `self` + /// and `other` are equal. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + fn count_eq(&self, other: &ArrayBase) -> Result + where + A: PartialEq; + + /// Counts the number of indices at which the elements of the arrays `self` + /// and `other` are not equal. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + fn count_neq(&self, other: &ArrayBase) -> Result + where + A: PartialEq; + + /// Computes the [squared L2 distance] between `self` and `other`. + /// + /// ```text + /// n + /// ∑ |aᵢ - bᵢ|² + /// i=1 + /// ``` + /// + /// where `self` is `a` and `other` is `b`. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// [squared L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance + fn sq_l2_dist(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed; + + /// Computes the [L2 distance] between `self` and `other`. + /// + /// ```text + /// n + /// √ ( ∑ |aᵢ - bᵢ|² ) + /// i=1 + /// ``` + /// + /// where `self` is `a` and `other` is `b`. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// **Panics** if the type cast from `A` to `f64` fails. + /// + /// [L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance + fn l2_dist(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive; + + /// Computes the [L1 distance] between `self` and `other`. + /// + /// ```text + /// n + /// ∑ |aᵢ - bᵢ| + /// i=1 + /// ``` + /// + /// where `self` is `a` and `other` is `b`. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// [L1 distance]: https://en.wikipedia.org/wiki/Taxicab_geometry + fn l1_dist(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed; + + /// Computes the [L∞ distance] between `self` and `other`. + /// + /// ```text + /// max(|aᵢ - bᵢ|) + /// ᵢ + /// ``` + /// + /// where `self` is `a` and `other` is `b`. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// [L∞ distance]: https://en.wikipedia.org/wiki/Chebyshev_distance + fn linf_dist(&self, other: &ArrayBase) -> Result + where + A: Clone + PartialOrd + Signed; + + /// Computes the [mean absolute error] between `self` and `other`. + /// + /// ```text + /// n + /// 1/n * ∑ |aᵢ - bᵢ| + /// i=1 + /// ``` + /// + /// where `self` is `a` and `other` is `b`. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// **Panics** if the type cast from `A` to `f64` fails. + /// + /// [mean absolute error]: https://en.wikipedia.org/wiki/Mean_absolute_error + fn mean_abs_err(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive; + + /// Computes the [mean squared error] between `self` and `other`. + /// + /// ```text + /// n + /// 1/n * ∑ |aᵢ - bᵢ|² + /// i=1 + /// ``` + /// + /// where `self` is `a` and `other` is `b`. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// **Panics** if the type cast from `A` to `f64` fails. + /// + /// [mean squared error]: https://en.wikipedia.org/wiki/Mean_squared_error + fn mean_sq_err(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive; + + /// Computes the unnormalized [root-mean-square error] between `self` and `other`. + /// + /// ```text + /// √ mse(a, b) + /// ``` + /// + /// where `self` is `a`, `other` is `b` and `mse` is the mean-squared-error. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// **Panics** if the type cast from `A` to `f64` fails. + /// + /// [root-mean-square error]: https://en.wikipedia.org/wiki/Root-mean-square_deviation + fn root_mean_sq_err(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive; + + /// Computes the [peak signal-to-noise ratio] between `self` and `other`. + /// + /// ```text + /// 10 * log10(maxv^2 / mse(a, b)) + /// ``` + /// + /// where `self` is `a`, `other` is `b`, `mse` is the mean-squared-error + /// and `maxv` is the maximum possible value either array can take. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// **Panics** if the type cast from `A` to `f64` fails. + /// + /// [peak signal-to-noise ratio]: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + fn peak_signal_to_noise_ratio( + &self, + other: &ArrayBase, + maxv: A, + ) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive; + + private_decl! {} +} + +macro_rules! return_err_if_empty { + ($arr:expr) => { + if $arr.len() == 0 { + return Err(MultiInputError::EmptyInput); + } + }; +} +macro_rules! return_err_unless_same_shape { + ($arr_a:expr, $arr_b:expr) => { + if $arr_a.shape() != $arr_b.shape() { + return Err(ShapeMismatch { + first_shape: $arr_a.shape().to_vec(), + second_shape: $arr_b.shape().to_vec(), + } + .into()); + } + }; +} + +impl DeviationExt for ArrayBase +where + S: Data, + D: Dimension, +{ + fn count_eq(&self, other: &ArrayBase) -> Result + where + A: PartialEq, + { + return_err_if_empty!(self); + return_err_unless_same_shape!(self, other); + + let mut count = 0; + + Zip::from(self).and(other).apply(|a, b| { + if a == b { + count += 1; + } + }); + + Ok(count) + } + + fn count_neq(&self, other: &ArrayBase) -> Result + where + A: PartialEq, + { + self.count_eq(other).map(|n_eq| self.len() - n_eq) + } + + fn sq_l2_dist(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed, + { + return_err_if_empty!(self); + return_err_unless_same_shape!(self, other); + + let mut result = A::zero(); + + Zip::from(self).and(other).apply(|self_i, other_i| { + let (a, b) = (self_i.clone(), other_i.clone()); + let abs_diff = (a - b).abs(); + result += abs_diff.clone() * abs_diff; + }); + + Ok(result) + } + + fn l2_dist(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive, + { + let sq_l2_dist = self + .sq_l2_dist(other)? + .to_f64() + .expect("failed cast from type A to f64"); + + Ok(sq_l2_dist.sqrt()) + } + + fn l1_dist(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed, + { + return_err_if_empty!(self); + return_err_unless_same_shape!(self, other); + + let mut result = A::zero(); + + Zip::from(self).and(other).apply(|self_i, other_i| { + let (a, b) = (self_i.clone(), other_i.clone()); + result += (a - b).abs(); + }); + + Ok(result) + } + + fn linf_dist(&self, other: &ArrayBase) -> Result + where + A: Clone + PartialOrd + Signed, + { + return_err_if_empty!(self); + return_err_unless_same_shape!(self, other); + + let mut max = A::zero(); + + Zip::from(self).and(other).apply(|self_i, other_i| { + let (a, b) = (self_i.clone(), other_i.clone()); + let diff = (a - b).abs(); + if diff > max { + max = diff; + } + }); + + Ok(max) + } + + fn mean_abs_err(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive, + { + let l1_dist = self + .l1_dist(other)? + .to_f64() + .expect("failed cast from type A to f64"); + let n = self.len() as f64; + + Ok(l1_dist / n) + } + + fn mean_sq_err(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive, + { + let sq_l2_dist = self + .sq_l2_dist(other)? + .to_f64() + .expect("failed cast from type A to f64"); + let n = self.len() as f64; + + Ok(sq_l2_dist / n) + } + + fn root_mean_sq_err(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive, + { + let msd = self.mean_sq_err(other)?; + Ok(msd.sqrt()) + } + + fn peak_signal_to_noise_ratio( + &self, + other: &ArrayBase, + maxv: A, + ) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive, + { + let maxv_f = maxv.to_f64().expect("failed cast from type A to f64"); + let msd = self.mean_sq_err(&other)?; + let psnr = 10. * f64::log10(maxv_f * maxv_f / msd); + + Ok(psnr) + } + + private_impl! {} +} diff --git a/src/errors.rs b/src/errors.rs index 2386a301..e2617f39 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -46,7 +46,7 @@ impl From for MinMaxError { /// An error used by methods and functions that take two arrays as argument and /// expect them to have exactly the same shape /// (e.g. `ShapeMismatch` is raised when `a.shape() == b.shape()` evaluates to `False`). -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct ShapeMismatch { pub first_shape: Vec, pub second_shape: Vec, @@ -65,7 +65,7 @@ impl fmt::Display for ShapeMismatch { impl Error for ShapeMismatch {} /// An error for methods that take multiple non-empty array inputs. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum MultiInputError { /// One or more of the arrays were empty. EmptyInput, diff --git a/src/lib.rs b/src/lib.rs index 6cf7f5ed..9ee3d350 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ //! - [partitioning]; //! - [correlation analysis] (covariance, pearson correlation); //! - [measures from information theory] (entropy, KL divergence, etc.); +//! - [measures of deviation] (count equal, L1, L2 distances, mean squared err etc.) //! - [histogram computation]. //! //! Please feel free to contribute new functionality! A roadmap can be found [here]. @@ -21,6 +22,7 @@ //! [partitioning]: trait.Sort1dExt.html //! [summary statistics]: trait.SummaryStatisticsExt.html //! [correlation analysis]: trait.CorrelationExt.html +//! [measures of deviation]: trait.DeviationExt.html //! [measures from information theory]: trait.EntropyExt.html //! [histogram computation]: histogram/index.html //! [here]: https://github.com/rust-ndarray/ndarray-stats/issues/1 @@ -28,6 +30,7 @@ //! [`StatsBase.jl`]: https://juliastats.github.io/StatsBase.jl/latest/ pub use crate::correlation::CorrelationExt; +pub use crate::deviation::DeviationExt; pub use crate::entropy::EntropyExt; pub use crate::histogram::HistogramExt; pub use crate::maybe_nan::{MaybeNan, MaybeNanExt}; @@ -69,6 +72,7 @@ mod private { } mod correlation; +mod deviation; mod entropy; pub mod errors; pub mod histogram; diff --git a/tests/deviation.rs b/tests/deviation.rs new file mode 100644 index 00000000..cae4aa89 --- /dev/null +++ b/tests/deviation.rs @@ -0,0 +1,252 @@ +use ndarray_stats::errors::{MultiInputError, ShapeMismatch}; +use ndarray_stats::DeviationExt; + +use approx::assert_abs_diff_eq; +use ndarray::{array, Array1}; +use num_bigint::BigInt; +use num_traits::Float; + +use std::f64; + +#[test] +fn test_count_eq() -> Result<(), MultiInputError> { + let a = array![0., 0.]; + let b = array![1., 0.]; + let c = array![0., 1.]; + let d = array![1., 1.]; + + assert_eq!(a.count_eq(&a)?, 2); + assert_eq!(a.count_eq(&b)?, 1); + assert_eq!(a.count_eq(&c)?, 1); + assert_eq!(a.count_eq(&d)?, 0); + + Ok(()) +} + +#[test] +fn test_count_neq() -> Result<(), MultiInputError> { + let a = array![0., 0.]; + let b = array![1., 0.]; + let c = array![0., 1.]; + let d = array![1., 1.]; + + assert_eq!(a.count_neq(&a)?, 0); + assert_eq!(a.count_neq(&b)?, 1); + assert_eq!(a.count_neq(&c)?, 1); + assert_eq!(a.count_neq(&d)?, 2); + + Ok(()) +} + +#[test] +fn test_sq_l2_dist() -> Result<(), MultiInputError> { + let a = array![0., 1., 4., 2.]; + let b = array![1., 1., 2., 4.]; + + assert_eq!(a.sq_l2_dist(&b)?, 9.); + + Ok(()) +} + +#[test] +fn test_l2_dist() -> Result<(), MultiInputError> { + let a = array![0., 1., 4., 2.]; + let b = array![1., 1., 2., 4.]; + + assert_eq!(a.l2_dist(&b)?, 3.); + + Ok(()) +} + +#[test] +fn test_l1_dist() -> Result<(), MultiInputError> { + let a = array![0., 1., 4., 2.]; + let b = array![1., 1., 2., 4.]; + + assert_eq!(a.l1_dist(&b)?, 5.); + + Ok(()) +} + +#[test] +fn test_linf_dist() -> Result<(), MultiInputError> { + let a = array![0., 0.]; + let b = array![1., 0.]; + let c = array![1., 2.]; + + assert_eq!(a.linf_dist(&a)?, 0.); + + assert_eq!(a.linf_dist(&b)?, 1.); + assert_eq!(b.linf_dist(&a)?, 1.); + + assert_eq!(a.linf_dist(&c)?, 2.); + assert_eq!(c.linf_dist(&a)?, 2.); + + Ok(()) +} + +#[test] +fn test_mean_abs_err() -> Result<(), MultiInputError> { + let a = array![1., 1.]; + let b = array![3., 5.]; + + assert_eq!(a.mean_abs_err(&a)?, 0.); + assert_eq!(a.mean_abs_err(&b)?, 3.); + assert_eq!(b.mean_abs_err(&a)?, 3.); + + Ok(()) +} + +#[test] +fn test_mean_sq_err() -> Result<(), MultiInputError> { + let a = array![1., 1.]; + let b = array![3., 5.]; + + assert_eq!(a.mean_sq_err(&a)?, 0.); + assert_eq!(a.mean_sq_err(&b)?, 10.); + assert_eq!(b.mean_sq_err(&a)?, 10.); + + Ok(()) +} + +#[test] +fn test_root_mean_sq_err() -> Result<(), MultiInputError> { + let a = array![1., 1.]; + let b = array![3., 5.]; + + assert_eq!(a.root_mean_sq_err(&a)?, 0.); + assert_abs_diff_eq!(a.root_mean_sq_err(&b)?, 10.0.sqrt()); + assert_abs_diff_eq!(b.root_mean_sq_err(&a)?, 10.0.sqrt()); + + Ok(()) +} + +#[test] +fn test_peak_signal_to_noise_ratio() -> Result<(), MultiInputError> { + let a = array![1., 1.]; + assert!(a.peak_signal_to_noise_ratio(&a, 1.)?.is_infinite()); + + let a = array![1., 2., 3., 4., 5., 6., 7.]; + let b = array![1., 3., 3., 4., 6., 7., 8.]; + let maxv = 8.; + let expected = 20. * Float::log10(maxv) - 10. * Float::log10(a.mean_sq_err(&b)?); + let actual = a.peak_signal_to_noise_ratio(&b, maxv)?; + + assert_abs_diff_eq!(actual, expected); + + Ok(()) +} + +#[test] +fn test_deviations_with_n_by_m_ints() -> Result<(), MultiInputError> { + let a = array![[0, 1], [4, 2]]; + let b = array![[1, 1], [2, 4]]; + + assert_eq!(a.count_eq(&a)?, 4); + assert_eq!(a.count_neq(&a)?, 0); + + assert_eq!(a.sq_l2_dist(&b)?, 9); + assert_eq!(a.l2_dist(&b)?, 3.); + assert_eq!(a.l1_dist(&b)?, 5); + assert_eq!(a.linf_dist(&b)?, 2); + + assert_abs_diff_eq!(a.mean_abs_err(&b)?, 1.25); + assert_abs_diff_eq!(a.mean_sq_err(&b)?, 2.25); + assert_abs_diff_eq!(a.root_mean_sq_err(&b)?, 1.5); + assert_abs_diff_eq!(a.peak_signal_to_noise_ratio(&b, 4)?, 8.519374645445623); + + Ok(()) +} + +#[test] +fn test_deviations_with_empty_receiver() { + let a: Array1 = array![]; + let b: Array1 = array![1.]; + + assert_eq!(a.count_eq(&b), Err(MultiInputError::EmptyInput)); + assert_eq!(a.count_neq(&b), Err(MultiInputError::EmptyInput)); + + assert_eq!(a.sq_l2_dist(&b), Err(MultiInputError::EmptyInput)); + assert_eq!(a.l2_dist(&b), Err(MultiInputError::EmptyInput)); + assert_eq!(a.l1_dist(&b), Err(MultiInputError::EmptyInput)); + assert_eq!(a.linf_dist(&b), Err(MultiInputError::EmptyInput)); + + assert_eq!(a.mean_abs_err(&b), Err(MultiInputError::EmptyInput)); + assert_eq!(a.mean_sq_err(&b), Err(MultiInputError::EmptyInput)); + assert_eq!(a.root_mean_sq_err(&b), Err(MultiInputError::EmptyInput)); + assert_eq!( + a.peak_signal_to_noise_ratio(&b, 0.), + Err(MultiInputError::EmptyInput) + ); +} + +#[test] +fn test_deviations_do_not_panic_if_nans() -> Result<(), MultiInputError> { + let a: Array1 = array![1., f64::NAN, 3., f64::NAN]; + let b: Array1 = array![1., f64::NAN, 3., 4.]; + + assert_eq!(a.count_eq(&b)?, 2); + assert_eq!(a.count_neq(&b)?, 2); + + assert!(a.sq_l2_dist(&b)?.is_nan()); + assert!(a.l2_dist(&b)?.is_nan()); + assert!(a.l1_dist(&b)?.is_nan()); + assert_eq!(a.linf_dist(&b)?, 0.); + + assert!(a.mean_abs_err(&b)?.is_nan()); + assert!(a.mean_sq_err(&b)?.is_nan()); + assert!(a.root_mean_sq_err(&b)?.is_nan()); + assert!(a.peak_signal_to_noise_ratio(&b, 0.)?.is_nan()); + + Ok(()) +} + +#[test] +fn test_deviations_with_empty_argument() { + let a: Array1 = array![1.]; + let b: Array1 = array![]; + + let shape_mismatch_err = MultiInputError::ShapeMismatch(ShapeMismatch { + first_shape: a.shape().to_vec(), + second_shape: b.shape().to_vec(), + }); + let expected_err_usize = Err(shape_mismatch_err.clone()); + let expected_err_f64 = Err(shape_mismatch_err); + + assert_eq!(a.count_eq(&b), expected_err_usize); + assert_eq!(a.count_neq(&b), expected_err_usize); + + assert_eq!(a.sq_l2_dist(&b), expected_err_f64); + assert_eq!(a.l2_dist(&b), expected_err_f64); + assert_eq!(a.l1_dist(&b), expected_err_f64); + assert_eq!(a.linf_dist(&b), expected_err_f64); + + assert_eq!(a.mean_abs_err(&b), expected_err_f64); + assert_eq!(a.mean_sq_err(&b), expected_err_f64); + assert_eq!(a.root_mean_sq_err(&b), expected_err_f64); + assert_eq!(a.peak_signal_to_noise_ratio(&b, 0.), expected_err_f64); +} + +#[test] +fn test_deviations_with_non_copyable() -> Result<(), MultiInputError> { + let a: Array1 = array![0.into(), 1.into(), 4.into(), 2.into()]; + let b: Array1 = array![1.into(), 1.into(), 2.into(), 4.into()]; + + assert_eq!(a.count_eq(&a)?, 4); + assert_eq!(a.count_neq(&a)?, 0); + + assert_eq!(a.sq_l2_dist(&b)?, 9.into()); + assert_eq!(a.l2_dist(&b)?, 3.); + assert_eq!(a.l1_dist(&b)?, 5.into()); + assert_eq!(a.linf_dist(&b)?, 2.into()); + + assert_abs_diff_eq!(a.mean_abs_err(&b)?, 1.25); + assert_abs_diff_eq!(a.mean_sq_err(&b)?, 2.25); + assert_abs_diff_eq!(a.root_mean_sq_err(&b)?, 1.5); + assert_abs_diff_eq!( + a.peak_signal_to_noise_ratio(&b, 4.into())?, + 8.519374645445623 + ); + + Ok(()) +}