Skip to content

Commit

Permalink
Add a sort function. (#2134)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Apr 28, 2024
1 parent 3b429f3 commit 805f3be
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
17 changes: 17 additions & 0 deletions candle-core/src/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,21 @@ impl Tensor {
// No need for a backward pass for arg sort.
self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
}

/// Sorts the tensor along the last dimension, returns the sorted tensor together with the
/// sorted indexes.
///
/// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
/// descending order. The sort is unstable so there is no guarantees on the final order when it
/// comes to ties.
pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
if !self.is_contiguous() {
return Err(crate::Error::RequiresContiguous {
op: "sort_last_dim",
});
}
let asort = self.arg_sort_last_dim(asc)?;
let sorted = self.gather(&asort, crate::D::Minus1)?;
Ok((sorted, asort))
}
}
18 changes: 18 additions & 0 deletions candle-core/tests/tensor_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,24 @@ fn asort(device: &Device) -> Result<()> {
indexes.to_vec2::<u32>()?,
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
);
let (sorted, indexes) = tensor.sort_last_dim(true)?;
assert_eq!(
indexes.to_vec2::<u32>()?,
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
);
assert_eq!(
sorted.to_vec2::<f32>()?,
[[1.0, 1.1, 3.0, 4.0, 5.0], [1.0, 2.0, 2.1, 7.0, 8.0]]
);
let (sorted, indexes) = tensor.sort_last_dim(false)?;
assert_eq!(
indexes.to_vec2::<u32>()?,
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
);
assert_eq!(
sorted.to_vec2::<f32>()?,
[[5.0, 4.0, 3.0, 1.1, 1.0], [8.0, 7.0, 2.1, 2.0, 1.0]]
);
Ok(())
}

Expand Down

0 comments on commit 805f3be

Please sign in to comment.