Skip to content

Commit

Permalink
Merge pull request #87 from andfoy/add_vstack
Browse files Browse the repository at this point in the history
Add vstack
  • Loading branch information
andfoy authored Nov 10, 2023
2 parents 38b7d14 + 9b226c9 commit 55fa5bb
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 0 deletions.
49 changes: 49 additions & 0 deletions lib/extorch/native/tensor/ops/manipulation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -946,5 +946,54 @@ defmodule ExTorch.Native.Tensor.Ops.Manipulation do
"""
@spec permute(ExTorch.Tensor.t(), tuple() | [integer()]) :: ExTorch.Tensor.t()
defbinding(permute(input, dims))

@doc """
Stack tensors in sequence vertically (row wise).
This is equivalent to concatenation along the first axis after all 1-D tensors
have been reshaped by `ExTorch.atleast_2d/1`.
## Arguments
- `tensors` (`[ExTorch.Tensor.t()] | tuple()`) - sequence of tensors to concatenate.
## Optional arguments
- out (`ExTorch.Tensor | nil`) - an optional pre-allocated tensor used to
store the output result. Default: `nil`
## Examples
iex> a = ExTorch.tensor([1, 2, 3])
iex> b = ExTorch.tensor([4, 5, 6])
iex> ExTorch.vstack({a, b})
#Tensor<
[[1, 2, 3],
[4, 5, 6]]
[size: {2, 3}, dtype: :byte, device: :cpu, requires_grad: false]>
iex> a = ExTorch.tensor([[1],[2],[3]])
#Tensor<
[[1],
[2],
[3]]
[size: {3, 1}, dtype: :byte, device: :cpu, requires_grad: false]>
iex> b = ExTorch.tensor([[4],[5],[6]])
#Tensor<
[[4],
[5],
[6]]
[size: {3, 1}, dtype: :byte, device: :cpu, requires_grad: false]>
iex> ExTorch.vstack([a, b])
#Tensor<
[[1],
[2],
[3],
[4],
[5],
[6]]
[size: {6, 1}, dtype: :byte, device: :cpu, requires_grad: false]>
"""
@spec vstack([ExTorch.Tensor.t()] | tuple(), ExTorch.Tensor.t() | nil) :: ExTorch.Tensor.t()
defbinding(vstack(tensors, out \\ nil), fn_aliases: [:row_stack])
end
end
2 changes: 2 additions & 0 deletions native/extorch/include/manipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,5 @@ TensorTuple nonzero(
std::shared_ptr<CrossTensor> permute(
const std::shared_ptr<CrossTensor> &input,
rust::Vec<int64_t> dims);

std::shared_ptr<CrossTensor> vstack(TensorList tensor_list, TensorOut opt_out);
14 changes: 14 additions & 0 deletions native/extorch/src/csrc/manipulation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -501,3 +501,17 @@ std::shared_ptr<CrossTensor> permute(

return std::make_shared<CrossTensor>(std::move(out_tensor));
}

std::shared_ptr<CrossTensor> vstack(TensorList tensor_list, TensorOut opt_out) {
std::vector<CrossTensor> tensor_vec = unpack_tensor_list(tensor_list);
CrossTensor out_tensor;

if(opt_out.used) {
out_tensor = *opt_out.tensor.get();
out_tensor = torch::vstack_out(out_tensor, tensor_vec);
} else {
out_tensor = torch::vstack(tensor_vec);
}

return std::make_shared<CrossTensor>(std::move(out_tensor));
}
1 change: 1 addition & 0 deletions native/extorch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ rustler::init!(
narrow_copy,
nonzero,
permute,
vstack,

// Tensor comparing operations
allclose,
Expand Down
3 changes: 3 additions & 0 deletions native/extorch/src/native/tensor/ops.rs.in
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,6 @@ fn permute(
input: &SharedPtr<CrossTensor>,
dims: Vec<i64>
) -> Result<SharedPtr<CrossTensor>>;

/// Creates a new tensor by stacking vertically a sequence of tensors.
fn vstack(tensors: TensorList, out: TensorOut) -> Result<SharedPtr<CrossTensor>>;
7 changes: 7 additions & 0 deletions native/extorch/src/nifs/tensor_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,10 @@ nif_impl!(
input: TensorStruct<'a>,
dims: Size
);

nif_impl!(
vstack,
TensorStruct<'a>,
tensors: TensorList,
out: TensorOut
);
15 changes: 15 additions & 0 deletions test/tensor/manipulation_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -367,4 +367,19 @@ defmodule ExTorchTest.Tensor.ManipulationTest do
out = ExTorch.permute(input, {2, -1, 0, 1})
assert out.size == {5, 2, 3, 4}
end

test "vstack/1" do
input = ExTorch.rand({5, 10, 2})
parts = [input[0..3], input[3..5]]
out = ExTorch.vstack(parts)
assert ExTorch.allclose(out, input)
end

test "vstack/2" do
input = ExTorch.rand({5, 10, 2})
parts = [input[0..3], input[3..5]]
out = ExTorch.empty_like(input)
ExTorch.vstack(parts, out)
assert ExTorch.allclose(out, input)
end
end

0 comments on commit 55fa5bb

Please sign in to comment.