From 7a8c3c2a2cb877a0532efea8825b961c0bb3a190 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Andr=C3=A9s=20Margffoy=20Tuay?= Date: Mon, 4 Dec 2023 14:01:22 -0500 Subject: [PATCH] Add take_along_dim --- lib/extorch/native/tensor/ops/manipulation.ex | 49 +++++++++++++++++++ native/extorch/include/manipulation.h | 6 +++ native/extorch/src/csrc/manipulation.cc | 24 +++++++++ native/extorch/src/lib.rs | 1 + native/extorch/src/native/tensor/ops.rs.in | 8 +++ native/extorch/src/nifs/tensor_ops.rs | 9 ++++ test/tensor/manipulation_test.exs | 26 ++++++++++ 7 files changed, 123 insertions(+) diff --git a/lib/extorch/native/tensor/ops/manipulation.ex b/lib/extorch/native/tensor/ops/manipulation.ex index ec9d787..4fd5961 100644 --- a/lib/extorch/native/tensor/ops/manipulation.ex +++ b/lib/extorch/native/tensor/ops/manipulation.ex @@ -1673,5 +1673,54 @@ defmodule ExTorch.Native.Tensor.Ops.Manipulation do """ @spec take(ExTorch.Tensor.t(), ExTorch.Tensor.t()) :: ExTorch.Tensor.t() defbinding(take(input, indices)) + + @doc """ + Selects values from `input` at the 1-dimensional indices from `indices` along the given `dim`. + + Functions that return indices along a dimension, like `ExTorch.argmax/3` and `ExTorch.argmin/3`, + are designed to work with this function. See the examples below. + + ## Arguments + - `input` (`ExTorch.Tensor`) - the input tensor. + - `indices` (`ExTorch.Tensor`) - the indices into `input`. It must have either `:long` or `:int64` dtype. + + ## Optional arguments + - `dim` (`integer`) - dimension to select along. If `nil`, then it will index all + the dimensions as a single one. Default: `nil`. + - `out` (`ExTorch.Tensor` or `nil`) - an optional pre-allocated tensor used to store the + output. Default: `nil` + + ## Notes + This function is similar to NumPy’s `take_along_axis`. See also `ExTorch.gather/5`. + + ## Examples + iex> t = ExTorch.tensor([[10, 30, 20], [60, 40, 50]], dtype: :long) + #Tensor< + [[10, 30, 20], + [60, 40, 50]] + [size: {2, 3}, dtype: :long, device: :cpu, requires_grad: false]> + iex> max_idx = ExTorch.argmax(t) + #Tensor< 3 [size: {}, dtype: :long, device: :cpu, requires_grad: false]> + iex> ExTorch.take_along_dim(t, max_idx) + #Tensor< [60] [size: {1}, dtype: :long, device: :cpu, requires_grad: false]> + + iex> sorted_idx = ExTorch.argsort(t, dim: 1) + #Tensor< + [[0, 2, 1], + [1, 2, 0]] + [size: {2, 3}, dtype: :long, device: :cpu, requires_grad: false]> + iex> ExTorch.take_along_dim(t, sorted_idx, dim: 1) + #Tensor< + [[10, 20, 30], + [40, 50, 60]] + [size: {2, 3}, dtype: :long, device: :cpu, requires_grad: false]> + """ + @spec take_along_dim( + ExTorch.Tensor.t(), + ExTorch.Tensor.t(), + integer() | nil, + ExTorch.Tensor.t() | nil + ) :: ExTorch.Tensor.t() + defbinding(take_along_dim(input, indices, dim \\ nil, out \\ nil)) end end diff --git a/native/extorch/include/manipulation.h b/native/extorch/include/manipulation.h index 47bdc35..eb32a4e 100644 --- a/native/extorch/include/manipulation.h +++ b/native/extorch/include/manipulation.h @@ -182,3 +182,9 @@ std::shared_ptr t(const std::shared_ptr &input); std::shared_ptr take( const std::shared_ptr &input, const std::shared_ptr &indices); + +std::shared_ptr take_along_dim( + const std::shared_ptr &input, + const std::shared_ptr &indices, + OptionalInt dim, + TensorOut out); diff --git a/native/extorch/src/csrc/manipulation.cc b/native/extorch/src/csrc/manipulation.cc index 9df6cd4..1ba5fa0 100644 --- a/native/extorch/src/csrc/manipulation.cc +++ b/native/extorch/src/csrc/manipulation.cc @@ -753,3 +753,27 @@ std::shared_ptr take( out_tensor = torch::take(in_tensor, indices_tensor); return std::make_shared(std::move(out_tensor)); } + +std::shared_ptr take_along_dim( + const std::shared_ptr &input, + const std::shared_ptr &indices, + OptionalInt dim, + TensorOut out) { + + CrossTensor out_tensor; + CrossTensor in_tensor = *input.get(); + CrossTensor indices_tensor = *indices.get(); + + torch::optional opt_dim = torch::nullopt; + if(dim.used) { + opt_dim = dim.value; + } + + if(out.used) { + out_tensor = *out.tensor.get(); + out_tensor = torch::take_along_dim_out(out_tensor, in_tensor, indices_tensor, opt_dim); + } else { + out_tensor = torch::take_along_dim(in_tensor, indices_tensor, opt_dim); + } + return std::make_shared(std::move(out_tensor)); +} diff --git a/native/extorch/src/lib.rs b/native/extorch/src/lib.rs index 6f65053..e1029ac 100644 --- a/native/extorch/src/lib.rs +++ b/native/extorch/src/lib.rs @@ -127,6 +127,7 @@ rustler::init!( stack, t, take, + take_along_dim, // Tensor comparing operations allclose, diff --git a/native/extorch/src/native/tensor/ops.rs.in b/native/extorch/src/native/tensor/ops.rs.in index ee6fe39..5b7781f 100644 --- a/native/extorch/src/native/tensor/ops.rs.in +++ b/native/extorch/src/native/tensor/ops.rs.in @@ -239,3 +239,11 @@ fn take( input: &SharedPtr, indices: &SharedPtr, ) -> Result>; + +/// Index a tensor alongside a dimension. +fn take_along_dim( + input: &SharedPtr, + indices: &SharedPtr, + dim: OptionalInt, + out: TensorOut +) -> Result>; diff --git a/native/extorch/src/nifs/tensor_ops.rs b/native/extorch/src/nifs/tensor_ops.rs index ca330ee..21b6e61 100644 --- a/native/extorch/src/nifs/tensor_ops.rs +++ b/native/extorch/src/nifs/tensor_ops.rs @@ -296,3 +296,12 @@ nif_impl!( input: TensorStruct<'a>, indices: TensorStruct<'a> ); + +nif_impl!( + take_along_dim, + TensorStruct<'a>, + input: TensorStruct<'a>, + indices: TensorStruct<'a>, + dim: OptionalInt, + out: TensorOut +); diff --git a/test/tensor/manipulation_test.exs b/test/tensor/manipulation_test.exs index 553b58d..f6aea6e 100644 --- a/test/tensor/manipulation_test.exs +++ b/test/tensor/manipulation_test.exs @@ -854,4 +854,30 @@ defmodule ExTorchTest.Tensor.ManipulationTest do out = ExTorch.take(input, ExTorch.tensor([1, 5, 6], dtype: :int64)) assert ExTorch.allclose(out, expected) end + + test "take_along_dim/2" do + input = ExTorch.rand({3, 4}) + expected = ExTorch.max(input) + indices = ExTorch.argmax(input) + + out = ExTorch.take_along_dim(input, indices) + assert ExTorch.allclose(out, expected) + end + + test "take_along_dim/3" do + input = ExTorch.rand({3, 4}) + {expected, indices} = ExTorch.sort(input, -1) + + out = ExTorch.take_along_dim(input, indices, -1) + assert ExTorch.allclose(out, expected) + end + + test "take_along_dim/4" do + input = ExTorch.rand({3, 4}) + {expected, indices} = ExTorch.min(input, 0, keepdim: true) + + out = ExTorch.empty_like(expected) + ExTorch.take_along_dim(input, indices, 0, out) + assert ExTorch.allclose(out, expected) + end end