-
Notifications
You must be signed in to change notification settings - Fork 306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add shuffle_axis_inplace
and shuffle_axis_inplace_using
#742
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,8 +34,7 @@ use crate::rand::rngs::SmallRng; | |
use crate::rand::seq::index; | ||
use crate::rand::{thread_rng, Rng, SeedableRng}; | ||
|
||
use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder}; | ||
use ndarray::{ArrayBase, DataOwned, Dimension}; | ||
use ndarray::{Array, Axis, Data, DataMut, RemoveAxis, ShapeBuilder, ArrayBase, DataOwned, Dimension}; | ||
#[cfg(feature = "quickcheck")] | ||
use quickcheck::{Arbitrary, Gen}; | ||
|
||
|
@@ -64,7 +63,7 @@ pub mod rand_distr { | |
/// [`.random_using()`](#tymethod.random_using). | ||
pub trait RandomExt<S, A, D> | ||
where | ||
S: DataOwned<Elem = A>, | ||
S: Data<Elem = A>, | ||
D: Dimension, | ||
{ | ||
/// Create an array with shape `dim` with elements drawn from | ||
|
@@ -87,6 +86,7 @@ where | |
/// # } | ||
fn random<Sh, IdS>(shape: Sh, distribution: IdS) -> ArrayBase<S, D> | ||
where | ||
S: DataOwned<Elem=A>, | ||
IdS: Distribution<S::Elem>, | ||
Sh: ShapeBuilder<Dim = D>; | ||
|
||
|
@@ -116,6 +116,7 @@ where | |
/// # } | ||
fn random_using<Sh, IdS, R>(shape: Sh, distribution: IdS, rng: &mut R) -> ArrayBase<S, D> | ||
where | ||
S: DataOwned<Elem=A>, | ||
IdS: Distribution<S::Elem>, | ||
R: Rng + ?Sized, | ||
Sh: ShapeBuilder<Dim = D>; | ||
|
@@ -225,17 +226,93 @@ where | |
R: Rng + ?Sized, | ||
A: Copy, | ||
D: RemoveAxis; | ||
|
||
/// Shuffle `self`'s slices along `axis`. | ||
/// | ||
/// It uses [Fisher-Yates shuffling algorithm](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle). | ||
/// | ||
/// ***Panics*** when creation of the RNG fails. | ||
/// ``` | ||
/// use ndarray::{array, Axis}; | ||
/// use ndarray_rand::RandomExt; | ||
/// | ||
/// # fn main() { | ||
/// let mut a = array![ | ||
/// [1., 2., 3.], | ||
/// [4., 5., 6.], | ||
/// [7., 8., 9.], | ||
/// [10., 11., 12.], | ||
/// ]; | ||
/// // Let's shuffle `a`'s columns! | ||
/// // Shuffling modifies the array in place, nothing is returned | ||
/// a.shuffle_axis_inplace(Axis(1)); | ||
/// println!("{:?}", a); | ||
/// // Example Output: | ||
/// // [ | ||
/// // [1., 3., 2.], | ||
/// // [4., 6., 5.], | ||
/// // [7., 9., 8.], | ||
/// // [10., 12., 11.], | ||
/// // ] | ||
/// # } | ||
/// ``` | ||
fn shuffle_axis_inplace(&mut self, axis: Axis) | ||
where | ||
D: RemoveAxis, | ||
S: DataMut; | ||
|
||
/// Shuffle `self`'s slices along `axis` using the specified random number generator `rng`. | ||
/// | ||
/// It uses [Fisher-Yates shuffling algorithm](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle). | ||
/// | ||
/// ***Panics*** when creation of the RNG fails. | ||
/// ``` | ||
/// use ndarray::{array, Axis}; | ||
/// use ndarray_rand::RandomExt; | ||
/// use ndarray_rand::rand::SeedableRng; | ||
/// use rand_isaac::isaac64::Isaac64Rng; | ||
/// | ||
/// # fn main() { | ||
/// // Get a seeded random number generator for reproducibility (Isaac64 algorithm) | ||
/// let seed = 42; | ||
/// let mut rng = Isaac64Rng::seed_from_u64(seed); | ||
/// | ||
/// let mut a = array![ | ||
/// [1., 2., 3.], | ||
/// [4., 5., 6.], | ||
/// [7., 8., 9.], | ||
/// [10., 11., 12.], | ||
/// ]; | ||
/// // Let's shuffle `a`'s rows! | ||
/// // Shuffling modifies the array in place, nothing is returned | ||
/// a.shuffle_axis_inplace_using(Axis(0), &mut rng); | ||
/// println!("{:?}", a); | ||
/// // Example Output: | ||
/// // [ | ||
/// // [7., 8., 9.], | ||
/// // [4., 5., 6.], | ||
/// // [10., 11., 12.], | ||
/// // [1., 2., 3.], | ||
/// // ] | ||
/// # } | ||
/// ``` | ||
fn shuffle_axis_inplace_using<R>(&mut self, axis: Axis, rng: &mut R) | ||
where | ||
R: Rng + ?Sized, | ||
D: RemoveAxis, | ||
S: DataMut; | ||
} | ||
|
||
impl<S, A, D> RandomExt<S, A, D> for ArrayBase<S, D> | ||
where | ||
S: DataOwned<Elem = A>, | ||
S: Data<Elem = A>, | ||
D: Dimension, | ||
{ | ||
fn random<Sh, IdS>(shape: Sh, dist: IdS) -> ArrayBase<S, D> | ||
where | ||
IdS: Distribution<S::Elem>, | ||
Sh: ShapeBuilder<Dim = D>, | ||
S: DataOwned<Elem=A>, | ||
{ | ||
Self::random_using(shape, dist, &mut get_rng()) | ||
} | ||
|
@@ -245,6 +322,7 @@ where | |
IdS: Distribution<S::Elem>, | ||
R: Rng + ?Sized, | ||
Sh: ShapeBuilder<Dim = D>, | ||
S: DataOwned<Elem=A>, | ||
{ | ||
Self::from_shape_fn(shape, |_| dist.sample(rng)) | ||
} | ||
|
@@ -280,6 +358,41 @@ where | |
}; | ||
self.select(axis, &indices) | ||
} | ||
|
||
fn shuffle_axis_inplace(&mut self, axis: Axis) | ||
where | ||
D: RemoveAxis, | ||
S: DataMut, | ||
{ | ||
self.shuffle_axis_inplace_using(axis, &mut get_rng()) | ||
} | ||
|
||
fn shuffle_axis_inplace_using<R>(&mut self, axis: Axis, rng: &mut R) | ||
where | ||
R: Rng + ?Sized, | ||
D: RemoveAxis, | ||
S: DataMut, | ||
{ | ||
for i in (1..self.len_of(axis)).rev() { | ||
// Invariant: elements with index > i have been locked in place. | ||
let j = rng.gen_range(0, i + 1); | ||
|
||
if i != j { | ||
// Swap the two slices along `axis` | ||
let slice1 = self.index_axis(axis, i); | ||
let slice2 = self.index_axis(axis, j); | ||
|
||
for (x, y) in slice1.iter().zip(slice2.iter()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are modifying through read-only/shared array views and this is invalid, unless we have interior mutability (i.e using Cell, Mutex or similar). And we'd like to avoid There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm working on making There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
On the other hand the method is taking a mutable reference to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Modifying through an ArrayView (without the Cell exception) is invalid full stop, because we treat it as we would treat a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Exceptions would come in to place if you somehow got a raw pointer out of an array view, maybe you'd then be able to make an argument. But since we go through explicit shared references here, we are at invalid full stop. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, no problem. The types matter and there are shared refs used on the path to the swap. If we read the rust reference document this is UB, we can't just cast from &T and mutate. |
||
// Swap the two elements. | ||
let ptr1 = x as *const A as *mut A; | ||
let ptr2 = y as *const A as *mut A; | ||
unsafe { | ||
std::ptr::swap(ptr1, ptr2); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
/// Used as parameter in [`sample_axis`] and [`sample_axis_using`] to determine | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shuffle the array's xyz along
axis
.We need a name for it, not sure if it should be slices. For the
.lanes(axis)
method we call them lanes - but those are perpendicular toaxis
in that case.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd be tempted to call them planes. We shuffle among the items in an AxisIter, right? Each such item is n - 1 dimensional in an ndarray.
That said it's good to have 2D explanations - how to shuffle rows and columns.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
An alternative name could be projections, which has the plus of not carrying an intuitive dimensionality along with it (I think 2D when I read planes).