diff --git a/lib/extorch/native/tensor/ops/indexing.ex b/lib/extorch/native/tensor/ops/indexing.ex index 3ff52a0..cb8b3e3 100644 --- a/lib/extorch/native/tensor/ops/indexing.ex +++ b/lib/extorch/native/tensor/ops/indexing.ex @@ -563,5 +563,53 @@ defmodule ExTorch.Native.Tensor.Ops.Indexing do @spec masked_select(ExTorch.Tensor.t(), ExTorch.Tensor.t(), ExTorch.Tensor.t() | nil) :: ExTorch.Tensor.t() defbinding(masked_select(input, mask, out \\ nil)) + + @doc """ + Slices the `input` tensor along the selected dimension at the given `index`. + This function returns a view of the original tensor with the given dimension removed. + + ## Arguments + - `input` (`ExTorch.Tensor`) - the input tensor. + - `dim` (`integer`) - the dimension to slice. + - `index` (`integer`) - the index to select. + + ## Notes + `ExTorch.select/3` is equivalent to slicing. For example, `ExTorch.select(0, index)` + is equivalent to `tensor[index]` and `ExTorch.select(2, index)` is equivalent to + `tensor[{:::, :::, index}]`. + + ## Examples + iex> a = ExTorch.arange(2 * 3 * 4) |> ExTorch.reshape({2, 3, 4}) + #Tensor< + [[[ 0.0000, 1.0000, 2.0000, 3.0000], + [ 4.0000, 5.0000, 6.0000, 7.0000], + [ 8.0000, 9.0000, 10.0000, 11.0000]], + + [[12.0000, 13.0000, 14.0000, 15.0000], + [16.0000, 17.0000, 18.0000, 19.0000], + [20.0000, 21.0000, 22.0000, 23.0000]]] + [size: {2, 3, 4}, dtype: :float, device: :cpu, requires_grad: false]> + + iex> ExTorch.select(a, 0, 1) + #Tensor< + [[12., 13., 14., 15.], + [16., 17., 18., 19.], + [20., 21., 22., 23.]] + [size: {3, 4}, dtype: :float, device: :cpu, requires_grad: false]> + + iex> ExTorch.select(a, 1, 0) + #Tensor< + [[ 0.0000, 1.0000, 2.0000, 3.0000], + [12.0000, 13.0000, 14.0000, 15.0000]] + [size: {2, 4}, dtype: :float, device: :cpu, requires_grad: false]> + + iex> ExTorch.select(a, 2, 2) + #Tensor< + [[ 2., 6., 10.], + [14., 18., 22.]] + [size: {2, 3}, dtype: :float, device: :cpu, requires_grad: false]> + """ + @spec select(ExTorch.Tensor.t(), integer(), integer()) :: ExTorch.Tensor.t() + defbinding(select(input, dim, index)) end end diff --git a/native/extorch/include/manipulation.h b/native/extorch/include/manipulation.h index f3751fb..64ccf47 100644 --- a/native/extorch/include/manipulation.h +++ b/native/extorch/include/manipulation.h @@ -111,3 +111,8 @@ std::shared_ptr permute( rust::Vec dims); std::shared_ptr vstack(TensorList tensor_list, TensorOut opt_out); + +std::shared_ptr select( + const std::shared_ptr &input, + int64_t dim, + int64_t index); diff --git a/native/extorch/src/csrc/manipulation.cc b/native/extorch/src/csrc/manipulation.cc index 2a195e7..73d8f2d 100644 --- a/native/extorch/src/csrc/manipulation.cc +++ b/native/extorch/src/csrc/manipulation.cc @@ -515,3 +515,14 @@ std::shared_ptr vstack(TensorList tensor_list, TensorOut opt_out) { return std::make_shared(std::move(out_tensor)); } + +std::shared_ptr select( + const std::shared_ptr &input, + int64_t dim, + int64_t index) { + + CrossTensor out_tensor; + CrossTensor in_tensor = *input.get(); + out_tensor = torch::select(in_tensor, dim, index); + return std::make_shared(std::move(out_tensor)); +} diff --git a/native/extorch/src/lib.rs b/native/extorch/src/lib.rs index 5b8bf94..f3fbda7 100644 --- a/native/extorch/src/lib.rs +++ b/native/extorch/src/lib.rs @@ -115,6 +115,7 @@ rustler::init!( nonzero, permute, vstack, + select, // Tensor comparing operations allclose, diff --git a/native/extorch/src/native/tensor/ops.rs.in b/native/extorch/src/native/tensor/ops.rs.in index 865a73b..f8bf1a6 100644 --- a/native/extorch/src/native/tensor/ops.rs.in +++ b/native/extorch/src/native/tensor/ops.rs.in @@ -155,3 +155,10 @@ fn permute( /// Creates a new tensor by stacking vertically a sequence of tensors. fn vstack(tensors: TensorList, out: TensorOut) -> Result>; + +/// Index a tensor in a given dimension. +fn select( + input: &SharedPtr, + dim: i64, + index_param: i64 +) -> Result>; diff --git a/native/extorch/src/nifs/tensor_ops.rs b/native/extorch/src/nifs/tensor_ops.rs index 1608b71..0427c31 100644 --- a/native/extorch/src/nifs/tensor_ops.rs +++ b/native/extorch/src/nifs/tensor_ops.rs @@ -190,3 +190,11 @@ nif_impl!( tensors: TensorList, out: TensorOut ); + +nif_impl!( + select, + TensorStruct<'a>, + input: TensorStruct<'a>, + dim: i64, + index_param: i64 +); diff --git a/test/tensor/indexing_test.exs b/test/tensor/indexing_test.exs index 091c3e1..d6d2a91 100644 --- a/test/tensor/indexing_test.exs +++ b/test/tensor/indexing_test.exs @@ -405,4 +405,11 @@ defmodule ExTorchTest.Tensor.IndexingTest do out = ExTorch.narrow(input, 1, ExTorch.tensor(1), 2) assert ExTorch.allclose(out, expected) end + + test "select/3" do + input = ExTorch.rand({5, 10, 2}) + expected = input[{:"::", 4}] + out = ExTorch.select(input, 1, 4) + assert ExTorch.allclose(out, expected) + end end