Skip to content

Commit

Permalink
Add take_along_dim
Browse files Browse the repository at this point in the history
  • Loading branch information
andfoy committed Dec 4, 2023
1 parent c4889b5 commit 7a8c3c2
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 0 deletions.
49 changes: 49 additions & 0 deletions lib/extorch/native/tensor/ops/manipulation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1673,5 +1673,54 @@ defmodule ExTorch.Native.Tensor.Ops.Manipulation do
"""
@spec take(ExTorch.Tensor.t(), ExTorch.Tensor.t()) :: ExTorch.Tensor.t()
defbinding(take(input, indices))

@doc """
Selects values from `input` at the 1-dimensional indices from `indices` along the given `dim`.
Functions that return indices along a dimension, like `ExTorch.argmax/3` and `ExTorch.argmin/3`,
are designed to work with this function. See the examples below.
## Arguments
- `input` (`ExTorch.Tensor`) - the input tensor.
- `indices` (`ExTorch.Tensor`) - the indices into `input`. It must have either `:long` or `:int64` dtype.
## Optional arguments
- `dim` (`integer`) - dimension to select along. If `nil`, then it will index all
the dimensions as a single one. Default: `nil`.
- `out` (`ExTorch.Tensor` or `nil`) - an optional pre-allocated tensor used to store the
output. Default: `nil`
## Notes
This function is similar to NumPy’s `take_along_axis`. See also `ExTorch.gather/5`.
## Examples
iex> t = ExTorch.tensor([[10, 30, 20], [60, 40, 50]], dtype: :long)
#Tensor<
[[10, 30, 20],
[60, 40, 50]]
[size: {2, 3}, dtype: :long, device: :cpu, requires_grad: false]>
iex> max_idx = ExTorch.argmax(t)
#Tensor< 3 [size: {}, dtype: :long, device: :cpu, requires_grad: false]>
iex> ExTorch.take_along_dim(t, max_idx)
#Tensor< [60] [size: {1}, dtype: :long, device: :cpu, requires_grad: false]>
iex> sorted_idx = ExTorch.argsort(t, dim: 1)
#Tensor<
[[0, 2, 1],
[1, 2, 0]]
[size: {2, 3}, dtype: :long, device: :cpu, requires_grad: false]>
iex> ExTorch.take_along_dim(t, sorted_idx, dim: 1)
#Tensor<
[[10, 20, 30],
[40, 50, 60]]
[size: {2, 3}, dtype: :long, device: :cpu, requires_grad: false]>
"""
@spec take_along_dim(
ExTorch.Tensor.t(),
ExTorch.Tensor.t(),
integer() | nil,
ExTorch.Tensor.t() | nil
) :: ExTorch.Tensor.t()
defbinding(take_along_dim(input, indices, dim \\ nil, out \\ nil))
end
end
6 changes: 6 additions & 0 deletions native/extorch/include/manipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,9 @@ std::shared_ptr<CrossTensor> t(const std::shared_ptr<CrossTensor> &input);
std::shared_ptr<CrossTensor> take(
const std::shared_ptr<CrossTensor> &input,
const std::shared_ptr<CrossTensor> &indices);

std::shared_ptr<CrossTensor> take_along_dim(
const std::shared_ptr<CrossTensor> &input,
const std::shared_ptr<CrossTensor> &indices,
OptionalInt dim,
TensorOut out);
24 changes: 24 additions & 0 deletions native/extorch/src/csrc/manipulation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -753,3 +753,27 @@ std::shared_ptr<CrossTensor> take(
out_tensor = torch::take(in_tensor, indices_tensor);
return std::make_shared<CrossTensor>(std::move(out_tensor));
}

std::shared_ptr<CrossTensor> take_along_dim(
const std::shared_ptr<CrossTensor> &input,
const std::shared_ptr<CrossTensor> &indices,
OptionalInt dim,
TensorOut out) {

CrossTensor out_tensor;
CrossTensor in_tensor = *input.get();
CrossTensor indices_tensor = *indices.get();

torch::optional<int64_t> opt_dim = torch::nullopt;
if(dim.used) {
opt_dim = dim.value;
}

if(out.used) {
out_tensor = *out.tensor.get();
out_tensor = torch::take_along_dim_out(out_tensor, in_tensor, indices_tensor, opt_dim);
} else {
out_tensor = torch::take_along_dim(in_tensor, indices_tensor, opt_dim);
}
return std::make_shared<CrossTensor>(std::move(out_tensor));
}
1 change: 1 addition & 0 deletions native/extorch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ rustler::init!(
stack,
t,
take,
take_along_dim,

// Tensor comparing operations
allclose,
Expand Down
8 changes: 8 additions & 0 deletions native/extorch/src/native/tensor/ops.rs.in
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,11 @@ fn take(
input: &SharedPtr<CrossTensor>,
indices: &SharedPtr<CrossTensor>,
) -> Result<SharedPtr<CrossTensor>>;

/// Index a tensor alongside a dimension.
fn take_along_dim(
input: &SharedPtr<CrossTensor>,
indices: &SharedPtr<CrossTensor>,
dim: OptionalInt,
out: TensorOut
) -> Result<SharedPtr<CrossTensor>>;
9 changes: 9 additions & 0 deletions native/extorch/src/nifs/tensor_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,12 @@ nif_impl!(
input: TensorStruct<'a>,
indices: TensorStruct<'a>
);

nif_impl!(
take_along_dim,
TensorStruct<'a>,
input: TensorStruct<'a>,
indices: TensorStruct<'a>,
dim: OptionalInt,
out: TensorOut
);
26 changes: 26 additions & 0 deletions test/tensor/manipulation_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -854,4 +854,30 @@ defmodule ExTorchTest.Tensor.ManipulationTest do
out = ExTorch.take(input, ExTorch.tensor([1, 5, 6], dtype: :int64))
assert ExTorch.allclose(out, expected)
end

test "take_along_dim/2" do
input = ExTorch.rand({3, 4})
expected = ExTorch.max(input)
indices = ExTorch.argmax(input)

out = ExTorch.take_along_dim(input, indices)
assert ExTorch.allclose(out, expected)
end

test "take_along_dim/3" do
input = ExTorch.rand({3, 4})
{expected, indices} = ExTorch.sort(input, -1)

out = ExTorch.take_along_dim(input, indices, -1)
assert ExTorch.allclose(out, expected)
end

test "take_along_dim/4" do
input = ExTorch.rand({3, 4})
{expected, indices} = ExTorch.min(input, 0, keepdim: true)

out = ExTorch.empty_like(expected)
ExTorch.take_along_dim(input, indices, 0, out)
assert ExTorch.allclose(out, expected)
end
end

0 comments on commit 7a8c3c2

Please sign in to comment.