Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shuffle_axis_inplace and shuffle_axis_inplace_using #742

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 117 additions & 4 deletions ndarray-rand/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ use crate::rand::rngs::SmallRng;
use crate::rand::seq::index;
use crate::rand::{thread_rng, Rng, SeedableRng};

use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder};
use ndarray::{ArrayBase, DataOwned, Dimension};
use ndarray::{Array, Axis, Data, DataMut, RemoveAxis, ShapeBuilder, ArrayBase, DataOwned, Dimension};
#[cfg(feature = "quickcheck")]
use quickcheck::{Arbitrary, Gen};

Expand Down Expand Up @@ -64,7 +63,7 @@ pub mod rand_distr {
/// [`.random_using()`](#tymethod.random_using).
pub trait RandomExt<S, A, D>
where
S: DataOwned<Elem = A>,
S: Data<Elem = A>,
D: Dimension,
{
/// Create an array with shape `dim` with elements drawn from
Expand All @@ -87,6 +86,7 @@ where
/// # }
fn random<Sh, IdS>(shape: Sh, distribution: IdS) -> ArrayBase<S, D>
where
S: DataOwned<Elem=A>,
IdS: Distribution<S::Elem>,
Sh: ShapeBuilder<Dim = D>;

Expand Down Expand Up @@ -116,6 +116,7 @@ where
/// # }
fn random_using<Sh, IdS, R>(shape: Sh, distribution: IdS, rng: &mut R) -> ArrayBase<S, D>
where
S: DataOwned<Elem=A>,
IdS: Distribution<S::Elem>,
R: Rng + ?Sized,
Sh: ShapeBuilder<Dim = D>;
Expand Down Expand Up @@ -225,17 +226,93 @@ where
R: Rng + ?Sized,
A: Copy,
D: RemoveAxis;

/// Shuffle `self`'s slices along `axis`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shuffle the array's xyz along axis.

We need a name for it, not sure if it should be slices. For the .lanes(axis) method we call them lanes - but those are perpendicular to axis in that case.

Copy link
Member

@bluss bluss Oct 10, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd be tempted to call them planes. We shuffle among the items in an AxisIter, right? Each such item is n - 1 dimensional in an ndarray.

That said it's good to have 2D explanations - how to shuffle rows and columns.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative name could be projections, which has the plus of not carrying an intuitive dimensionality along with it (I think 2D when I read planes).

///
/// It uses [Fisher-Yates shuffling algorithm](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle).
///
/// ***Panics*** when creation of the RNG fails.
/// ```
/// use ndarray::{array, Axis};
/// use ndarray_rand::RandomExt;
///
/// # fn main() {
/// let mut a = array![
/// [1., 2., 3.],
/// [4., 5., 6.],
/// [7., 8., 9.],
/// [10., 11., 12.],
/// ];
/// // Let's shuffle `a`'s columns!
/// // Shuffling modifies the array in place, nothing is returned
/// a.shuffle_axis_inplace(Axis(1));
/// println!("{:?}", a);
/// // Example Output:
/// // [
/// // [1., 3., 2.],
/// // [4., 6., 5.],
/// // [7., 9., 8.],
/// // [10., 12., 11.],
/// // ]
/// # }
/// ```
fn shuffle_axis_inplace(&mut self, axis: Axis)
where
D: RemoveAxis,
S: DataMut;

/// Shuffle `self`'s slices along `axis` using the specified random number generator `rng`.
///
/// It uses [Fisher-Yates shuffling algorithm](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle).
///
/// ***Panics*** when creation of the RNG fails.
/// ```
/// use ndarray::{array, Axis};
/// use ndarray_rand::RandomExt;
/// use ndarray_rand::rand::SeedableRng;
/// use rand_isaac::isaac64::Isaac64Rng;
///
/// # fn main() {
/// // Get a seeded random number generator for reproducibility (Isaac64 algorithm)
/// let seed = 42;
/// let mut rng = Isaac64Rng::seed_from_u64(seed);
///
/// let mut a = array![
/// [1., 2., 3.],
/// [4., 5., 6.],
/// [7., 8., 9.],
/// [10., 11., 12.],
/// ];
/// // Let's shuffle `a`'s rows!
/// // Shuffling modifies the array in place, nothing is returned
/// a.shuffle_axis_inplace_using(Axis(0), &mut rng);
/// println!("{:?}", a);
/// // Example Output:
/// // [
/// // [7., 8., 9.],
/// // [4., 5., 6.],
/// // [10., 11., 12.],
/// // [1., 2., 3.],
/// // ]
/// # }
/// ```
fn shuffle_axis_inplace_using<R>(&mut self, axis: Axis, rng: &mut R)
where
R: Rng + ?Sized,
D: RemoveAxis,
S: DataMut;
}

impl<S, A, D> RandomExt<S, A, D> for ArrayBase<S, D>
where
S: DataOwned<Elem = A>,
S: Data<Elem = A>,
D: Dimension,
{
fn random<Sh, IdS>(shape: Sh, dist: IdS) -> ArrayBase<S, D>
where
IdS: Distribution<S::Elem>,
Sh: ShapeBuilder<Dim = D>,
S: DataOwned<Elem=A>,
{
Self::random_using(shape, dist, &mut get_rng())
}
Expand All @@ -245,6 +322,7 @@ where
IdS: Distribution<S::Elem>,
R: Rng + ?Sized,
Sh: ShapeBuilder<Dim = D>,
S: DataOwned<Elem=A>,
{
Self::from_shape_fn(shape, |_| dist.sample(rng))
}
Expand Down Expand Up @@ -280,6 +358,41 @@ where
};
self.select(axis, &indices)
}

fn shuffle_axis_inplace(&mut self, axis: Axis)
where
D: RemoveAxis,
S: DataMut,
{
self.shuffle_axis_inplace_using(axis, &mut get_rng())
}

fn shuffle_axis_inplace_using<R>(&mut self, axis: Axis, rng: &mut R)
where
R: Rng + ?Sized,
D: RemoveAxis,
S: DataMut,
{
for i in (1..self.len_of(axis)).rev() {
// Invariant: elements with index > i have been locked in place.
let j = rng.gen_range(0, i + 1);

if i != j {
// Swap the two slices along `axis`
let slice1 = self.index_axis(axis, i);
let slice2 = self.index_axis(axis, j);

for (x, y) in slice1.iter().zip(slice2.iter()) {
Copy link
Member

@bluss bluss Oct 9, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are modifying through read-only/shared array views and this is invalid, unless we have interior mutability (i.e using Cell, Mutex or similar).

And we'd like to avoid iter().zip() for ndarray, because it is much slower than it needs to be, we have Zip.

Copy link
Member

@bluss bluss Oct 9, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm working on making ArrayView<Cell<T>, D> a reality, a conversion we can use sometimes, - it allows us to have.. more fun like this. :)

Copy link
Member Author

@LukeMathWalker LukeMathWalker Oct 9, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are modifying through read-only/shared array views and this is invalid, unless we have interior mutability (i.e using Cell, Mutex or similar).

On the other hand the method is taking a mutable reference to self, hence we are guaranteed to be the only ones operating on the array - we are opting out of the compiler safety guarantees because we know that two slices are not overlapping and that we have a unique handle on the data, hence the references are valid and to the outside world the operation should be safe.
Or am I interpreting this incorrectly?

Copy link
Member

@bluss bluss Oct 9, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modifying through an ArrayView (without the Cell exception) is invalid full stop, because we treat it as we would treat a &[T] or &T. (We could go into details and exceptions at another time, maybe there's some argument for a more lenient view?)

Copy link
Member

@bluss bluss Oct 9, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x is a &T but we are modifying through it. That's not permitted.

Exceptions would come in to place if you somehow got a raw pointer out of an array view, maybe you'd then be able to make an argument. But since we go through explicit shared references here, we are at invalid full stop.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the S: DataMut trait bound on the method, aren't we dealing with an ArrayViewMut, which is equivalent to a &mut [T]?
Just trying to truly understand, I am not trying to argue 😅

Copy link
Member

@bluss bluss Oct 10, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, no problem. The types matter and there are shared refs used on the path to the swap. If we read the rust reference document this is UB, we can't just cast from &T and mutate.

// Swap the two elements.
let ptr1 = x as *const A as *mut A;
let ptr2 = y as *const A as *mut A;
unsafe {
std::ptr::swap(ptr1, ptr2);
}
}
}
}
}
}

/// Used as parameter in [`sample_axis`] and [`sample_axis_using`] to determine
Expand Down
17 changes: 17 additions & 0 deletions ndarray-rand/tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,20 @@ fn sampling_with_replacement_from_a_zero_length_axis_should_panic() {
let a = Array::random((0, n), Uniform::new(0., 2.));
let _samples = a.sample_axis(Axis(0), 1, SamplingStrategy::WithReplacement);
}

quickcheck! {
fn shuffling_works(m: usize, n: usize) -> bool {
let a = Array::random((m, n), Uniform::new(0., 2.));

// Get a clone of `a` and shuffle it in place
let mut results = vec![];
for &axis in &[Axis(0), Axis(1)] {
let mut b = a.clone();
b.shuffle_axis_inplace(axis);

let result = b.axis_iter(axis).all(|lane| is_subset(&a, &lane, axis));
results.push(result)
}
results.into_iter().all(|p| p)
}
}