Skip to content

Commit

Permalink
Add select
Browse files Browse the repository at this point in the history
  • Loading branch information
andfoy committed Nov 10, 2023
1 parent 55fa5bb commit 701ccfb
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 0 deletions.
48 changes: 48 additions & 0 deletions lib/extorch/native/tensor/ops/indexing.ex
Original file line number Diff line number Diff line change
Expand Up @@ -563,5 +563,53 @@ defmodule ExTorch.Native.Tensor.Ops.Indexing do
@spec masked_select(ExTorch.Tensor.t(), ExTorch.Tensor.t(), ExTorch.Tensor.t() | nil) ::
ExTorch.Tensor.t()
defbinding(masked_select(input, mask, out \\ nil))

@doc """
Slices the `input` tensor along the selected dimension at the given `index`.
This function returns a view of the original tensor with the given dimension removed.
## Arguments
- `input` (`ExTorch.Tensor`) - the input tensor.
- `dim` (`integer`) - the dimension to slice.
- `index` (`integer`) - the index to select.
## Notes
`ExTorch.select/3` is equivalent to slicing. For example, `ExTorch.select(0, index)`
is equivalent to `tensor[index]` and `ExTorch.select(2, index)` is equivalent to
`tensor[{:::, :::, index}]`.
## Examples
iex> a = ExTorch.arange(2 * 3 * 4) |> ExTorch.reshape({2, 3, 4})
#Tensor<
[[[ 0.0000, 1.0000, 2.0000, 3.0000],
[ 4.0000, 5.0000, 6.0000, 7.0000],
[ 8.0000, 9.0000, 10.0000, 11.0000]],
[[12.0000, 13.0000, 14.0000, 15.0000],
[16.0000, 17.0000, 18.0000, 19.0000],
[20.0000, 21.0000, 22.0000, 23.0000]]]
[size: {2, 3, 4}, dtype: :float, device: :cpu, requires_grad: false]>
iex> ExTorch.select(a, 0, 1)
#Tensor<
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]
[size: {3, 4}, dtype: :float, device: :cpu, requires_grad: false]>
iex> ExTorch.select(a, 1, 0)
#Tensor<
[[ 0.0000, 1.0000, 2.0000, 3.0000],
[12.0000, 13.0000, 14.0000, 15.0000]]
[size: {2, 4}, dtype: :float, device: :cpu, requires_grad: false]>
iex> ExTorch.select(a, 2, 2)
#Tensor<
[[ 2., 6., 10.],
[14., 18., 22.]]
[size: {2, 3}, dtype: :float, device: :cpu, requires_grad: false]>
"""
@spec select(ExTorch.Tensor.t(), integer(), integer()) :: ExTorch.Tensor.t()
defbinding(select(input, dim, index))
end
end
5 changes: 5 additions & 0 deletions native/extorch/include/manipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,8 @@ std::shared_ptr<CrossTensor> permute(
rust::Vec<int64_t> dims);

std::shared_ptr<CrossTensor> vstack(TensorList tensor_list, TensorOut opt_out);

std::shared_ptr<CrossTensor> select(
const std::shared_ptr<CrossTensor> &input,
int64_t dim,
int64_t index);
11 changes: 11 additions & 0 deletions native/extorch/src/csrc/manipulation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,14 @@ std::shared_ptr<CrossTensor> vstack(TensorList tensor_list, TensorOut opt_out) {

return std::make_shared<CrossTensor>(std::move(out_tensor));
}

std::shared_ptr<CrossTensor> select(
const std::shared_ptr<CrossTensor> &input,
int64_t dim,
int64_t index) {

CrossTensor out_tensor;
CrossTensor in_tensor = *input.get();
out_tensor = torch::select(in_tensor, dim, index);
return std::make_shared<CrossTensor>(std::move(out_tensor));
}
1 change: 1 addition & 0 deletions native/extorch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ rustler::init!(
nonzero,
permute,
vstack,
select,

// Tensor comparing operations
allclose,
Expand Down
7 changes: 7 additions & 0 deletions native/extorch/src/native/tensor/ops.rs.in
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,10 @@ fn permute(

/// Creates a new tensor by stacking vertically a sequence of tensors.
fn vstack(tensors: TensorList, out: TensorOut) -> Result<SharedPtr<CrossTensor>>;

/// Index a tensor in a given dimension.
fn select(
input: &SharedPtr<CrossTensor>,
dim: i64,
index_param: i64
) -> Result<SharedPtr<CrossTensor>>;
8 changes: 8 additions & 0 deletions native/extorch/src/nifs/tensor_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,11 @@ nif_impl!(
tensors: TensorList,
out: TensorOut
);

nif_impl!(
select,
TensorStruct<'a>,
input: TensorStruct<'a>,
dim: i64,
index_param: i64
);
7 changes: 7 additions & 0 deletions test/tensor/indexing_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -405,4 +405,11 @@ defmodule ExTorchTest.Tensor.IndexingTest do
out = ExTorch.narrow(input, 1, ExTorch.tensor(1), 2)
assert ExTorch.allclose(out, expected)
end

test "select/3" do
input = ExTorch.rand({5, 10, 2})
expected = input[{:"::", 4}]
out = ExTorch.select(input, 1, 4)
assert ExTorch.allclose(out, expected)
end
end

0 comments on commit 701ccfb

Please sign in to comment.