Merge pull request #96 from andfoy/add_squeeze
Add squeeze
andfoy authored Dec 2, 2023
2 parents ffdb57b + 72cc29c commit 53d2f0b
Showing 7 changed files with 107 additions and 14 deletions.
47 changes: 47 additions & 0 deletions lib/extorch/native/tensor/ops/manipulation.ex
@@ -1508,5 +1508,52 @@ defmodule ExTorch.Native.Tensor.Ops.Manipulation do
ExTorch.Tensor.t()
]
defbinding(split(tensor, split_size_or_sections, dim \\ 0))

@doc """
Returns a tensor with all specified dimensions of `input` of size 1 removed.
For example, if `input` is of shape: $\\(A \\times 1 \\times B \\times C \\times 1 \\times D\\)$ then
`ExTorch.squeeze(input)` will be of shape: $\\(A \\times B \\times C \\times D \\)$.
When `dim` is given, a squeeze operation is done only in the given dimension(s).
If `input` is of shape: $\\(A \\times 1 \\times B \\)$, `squeeze(input, 0)` leaves the tensor
unchanged, but `squeeze(input, 1)` will squeeze the tensor to the shape $\\(A \\times B \\)$.
## Arguments
- `input` (`ExTorch.Tensor`) - the input tensor.
## Optional arguments
- `dim` (`integer` or `tuple` or `[integer]` or `nil`) - the dimension(s) to squeeze from `input`.
If `nil`, then all singleton dimensions will be squeezed.
## Notes
1. The returned tensor shares the storage with the `input` tensor, so changing the
contents of one will change the contents of the other.
2. If the tensor has a batch dimension of size 1, then `squeeze(input)` will also
remove the batch dimension, which can lead to unexpected errors.
Consider specifying only the dims you wish to be squeezed.
## Examples
iex> a = ExTorch.empty({1, 3, 1, 4, 1, 5})
iex> a.size
{1, 3, 1, 4, 5}
# Squeeze all singleton dimensions
iex> b = ExTorch.squeeze(a)
iex> b.size
{3, 4, 5}
# Squeeze a particular dimension
iex> b = ExTorch.squeeze(a, -2)
iex> b.size
{1, 3, 1, 4, 5}
# Squeeze particular dimensions
iex> b = ExTorch.squeeze(a, {2, 4})
iex> b.size
{1, 3, 4, 5}
"""
@spec squeeze(ExTorch.Tensor.t(), integer() | tuple() | [integer()] | nil) :: ExTorch.Tensor.t()
defbinding(squeeze(input, dim \\ nil))
end
end
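To complement the docstring examples, here is a minimal usage sketch of the list-of-dims form allowed by the typespec. This is a hypothetical iex session: it assumes the binding accepts a plain list of dimensions as documented, and that squeezing a non-singleton dimension leaves the tensor unchanged, as the docstring states for `squeeze(input, 0)`.

    iex> a = ExTorch.empty({1, 3, 1, 4, 1, 5})
    iex> ExTorch.squeeze(a, [0, 2]).size   # dims 0 and 2 are singleton, so both are removed
    {3, 4, 1, 5}
    iex> ExTorch.squeeze(a, 1).size        # dim 1 has size 3, so nothing changes
    {1, 3, 1, 4, 1, 5}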
4 changes: 4 additions & 0 deletions native/extorch/include/manipulation.h
@@ -170,3 +170,7 @@ std::shared_ptr<CrossTensor> scatter_reduce(
TensorList split(
const std::shared_ptr<CrossTensor> &input, IntListOrInt indices_or_sections,
int64_t dim);

std::shared_ptr<CrossTensor> squeeze(
    const std::shared_ptr<CrossTensor> &input,
    rust::Vec<int64_t> dims);
13 changes: 13 additions & 0 deletions native/extorch/src/csrc/manipulation.cc
@@ -709,3 +709,16 @@ TensorList split(
}
return pack_tensor_list(seq);
}

std::shared_ptr<CrossTensor> squeeze(
    const std::shared_ptr<CrossTensor> &input,
    rust::Vec<int64_t> dims) {
    CrossTensor out_tensor;
    CrossTensor in_tensor = *input.get();
    if(dims.size() == 0) {
        out_tensor = torch::squeeze(in_tensor);
    } else {
        out_tensor = torch::squeeze(in_tensor, torch::IntArrayRef{dims.data(), dims.size()});
    }
    return std::make_shared<CrossTensor>(std::move(out_tensor));
}
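The shim above uses an empty `dims` vector to mean "squeeze every singleton dimension": when `dims.size() == 0` it calls the single-argument `torch::squeeze`, otherwise it forwards the requested dims as an `IntArrayRef`. Presumably the `nil` default on the Elixir side is what produces that empty vector, so both calls below should take the first branch. A hypothetical session, assuming `ExTorch.empty/1` and the `size` field behave as in the tests:

    iex> t = ExTorch.empty({1, 2, 1})
    iex> ExTorch.squeeze(t).size
    {2}
    iex> ExTorch.squeeze(t, nil).size
    {2}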
1 change: 1 addition & 0 deletions native/extorch/src/lib.rs
@@ -123,6 +123,7 @@ rustler::init!(
scatter_add,
scatter_reduce,
split,
squeeze,

// Tensor comparing operations
allclose,
28 changes: 14 additions & 14 deletions native/extorch/src/native/tensor/ops.rs.in
@@ -141,17 +141,10 @@ fn narrow_copy(
) -> Result<SharedPtr<CrossTensor>>;

/// Retrieve the indices of all non-zero elements in a tensor.
fn nonzero(
input: &SharedPtr<CrossTensor>,
out: TensorOut,
as_tuple: bool
) -> Result<TensorTuple>;
fn nonzero(input: &SharedPtr<CrossTensor>, out: TensorOut, as_tuple: bool) -> Result<TensorTuple>;

/// Permute a tensor dimensions and return the result as a view.
fn permute(
input: &SharedPtr<CrossTensor>,
dims: Vec<i64>
) -> Result<SharedPtr<CrossTensor>>;
fn permute(input: &SharedPtr<CrossTensor>, dims: Vec<i64>) -> Result<SharedPtr<CrossTensor>>;

/// Creates a new tensor by stacking vertically a sequence of tensors.
fn vstack(tensors: TensorList, out: TensorOut) -> Result<SharedPtr<CrossTensor>>;
@@ -160,7 +153,7 @@ fn vstack(tensors: TensorList, out: TensorOut) -> Result<SharedPtr<CrossTensor>>
fn select(
input: &SharedPtr<CrossTensor>,
dim: i64,
index_param: i64
index_param: i64,
) -> Result<SharedPtr<CrossTensor>>;

/// Writes all values from the tensor `src` into `input` at the indices specified in the `index` tensor.
@@ -180,7 +173,7 @@ fn diagonal_scatter(
offset: i64,
dim1: i64,
dim2: i64,
out: TensorOut
out: TensorOut,
) -> Result<SharedPtr<CrossTensor>>;

/// Embeds the values of the `src` tensor into input at the given `index`.
@@ -189,7 +182,7 @@ fn select_scatter(
src: &SharedPtr<CrossTensor>,
dim: i64,
index: i64,
out: TensorOut
out: TensorOut,
) -> Result<SharedPtr<CrossTensor>>;

/// Embeds the values of the `src` tensor into `input` at the given dimension.
@@ -200,7 +193,7 @@ fn slice_scatter(
start: OptionalInt,
end: OptionalInt,
step: i64,
out: TensorOut
out: TensorOut,
) -> Result<SharedPtr<CrossTensor>>;

/// Adds all values from the tensor `src` into `input` at the indices specified in the `index` tensor.
@@ -226,4 +219,11 @@ fn scatter_reduce(
) -> Result<SharedPtr<CrossTensor>>;

/// Split a tensor of one or more dimensions across an axis according to indices_or_sections.
fn split(input: &SharedPtr<CrossTensor>, indices_or_sections: IntListOrInt, dim: i64) -> Result<TensorList>;
fn split(
    input: &SharedPtr<CrossTensor>,
    indices_or_sections: IntListOrInt,
    dim: i64,
) -> Result<TensorList>;

/// Remove specified or all singleton dimensions from an input tensor.
fn squeeze(input: &SharedPtr<CrossTensor>, dims: Vec<i64>) -> Result<SharedPtr<CrossTensor>>;
7 changes: 7 additions & 0 deletions native/extorch/src/nifs/tensor_ops.rs
@@ -274,3 +274,10 @@ nif_impl!(
indices_or_sections: IntListOrInt,
dim: i64
);

nif_impl!(
    squeeze,
    TensorStruct<'a>,
    input: TensorStruct<'a>,
    dims: Size
);
21 changes: 21 additions & 0 deletions test/tensor/manipulation_test.exs
@@ -782,4 +782,25 @@ defmodule ExTorchTest.Tensor.ManipulationTest do
|> Enum.zip(expected)
|> Enum.reduce(true, fn {o, e}, acc -> ExTorch.allclose(o, e) and acc end)
end

test "squeeze/1" do
input = ExTorch.empty({1, 3, 1, 10, 1})
expected = {3, 10}
out = ExTorch.squeeze(input)
assert out.size == expected
end

test "squeeze/2 with a single dimension" do
input = ExTorch.empty({1, 3, 1, 10, 1})
expected = {1, 3, 10, 1}
out = ExTorch.squeeze(input, 2)
assert out.size == expected
end

test "squeeze/2 with a dimension list" do
input = ExTorch.empty({1, 3, 1, 10, 1})
expected = {3, 10, 1}
out = ExTorch.squeeze(input, {0, 2})
assert out.size == expected
end
end
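The docstring exercises a negative dimension (`-2`), but the new tests only cover non-negative ones. A follow-up test along these lines could close that gap; this is a sketch only, reusing the same helpers as the tests above and assuming negative indices wrap as in the docstring example:

    test "squeeze/2 with a negative dimension" do
      input = ExTorch.empty({1, 3, 1, 10, 1})
      expected = {1, 3, 1, 10}
      out = ExTorch.squeeze(input, -1)
      assert out.size == expected
    end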
