Skip to content

Commit

Permalink
Add diagonal_scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
andfoy committed Nov 18, 2023
1 parent 4006088 commit e488310
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 0 deletions.
59 changes: 59 additions & 0 deletions lib/extorch/native/tensor/ops/manipulation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1092,5 +1092,64 @@ defmodule ExTorch.Native.Tensor.Ops.Manipulation do
false -> src
end
)

  @doc """
  Embeds the values of the `src` tensor into `input` along the diagonal elements of `input`,
  with respect to `dim1` and `dim2`.

  This function returns a tensor with fresh storage; it does not return a view.

  The argument `offset` controls which diagonal to consider:
  * If `offset = 0`, it is the main diagonal.
  * If `offset > 0`, it is above the main diagonal.
  * If `offset < 0`, it is below the main diagonal.

  ## Arguments
  - `input` (`ExTorch.Tensor`) - the input tensor. Must be at least 2-dimensional.
  - `src` (`ExTorch.Tensor`) - the tensor to embed into `input`.
  - `offset` (`integer`) - which diagonal to consider. Default: 0 (main diagonal).
  - `dim1` (`integer`) - first dimension with respect to which to take diagonal. Default: 0.
  - `dim2` (`integer`) - second dimension with respect to which to take diagonal. Default: 1.

  ## Optional arguments
  - `out` (`ExTorch.Tensor` or `nil`) - an optional pre-allocated tensor used to
  store the output result. Default: `nil`

  ## Notes
  `src` must be of the proper size in order to be embedded into `input`. Specifically, it should have
  the same shape as `ExTorch.diagonal(input, offset, dim1, dim2)`

  ## Examples
      iex> a = ExTorch.zeros({3, 3})
      #Tensor<
      [[ 0., 0., 0.],
       [ 0., 0., 0.],
       [ 0., 0., 0.]]
      [size: {3, 3}, dtype: :float, device: :cpu, requires_grad: false]>

      iex> ExTorch.diagonal_scatter(a, ExTorch.ones(3), 0)
      #Tensor<
      [[1.0000, 0.0000, 0.0000],
       [0.0000, 1.0000, 0.0000],
       [0.0000, 0.0000, 1.0000]]
      [size: {3, 3}, dtype: :float, device: :cpu, requires_grad: false]>

      iex> ExTorch.diagonal_scatter(a, ExTorch.ones(2), 1)
      #Tensor<
      [[0.0000, 1.0000, 0.0000],
       [0.0000, 0.0000, 1.0000],
       [0.0000, 0.0000, 0.0000]]
      [size: {3, 3}, dtype: :float, device: :cpu, requires_grad: false]>
  """
  @spec diagonal_scatter(
          ExTorch.Tensor.t(),
          ExTorch.Tensor.t(),
          integer(),
          integer(),
          integer(),
          ExTorch.Tensor.t() | nil
        ) :: ExTorch.Tensor.t()
  defbinding(diagonal_scatter(input, src, offset \\ 0, dim1 \\ 0, dim2 \\ 1, out \\ nil))
end
end
8 changes: 8 additions & 0 deletions native/extorch/include/manipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,11 @@ std::shared_ptr<CrossTensor> scatter(
const std::shared_ptr<CrossTensor> &src,
TensorOut out,
const bool inplace);

/// Embeds the values of `src` into `input` along the diagonal selected by
/// `offset` with respect to dimensions `dim1` and `dim2`, returning a tensor
/// with fresh storage. When `out.used` is true, the result is written into
/// the tensor carried by `out` instead of a newly allocated one.
std::shared_ptr<CrossTensor> diagonal_scatter(
    const std::shared_ptr<CrossTensor> &input,
    const std::shared_ptr<CrossTensor> &src,
    int64_t offset,
    int64_t dim1,
    int64_t dim2,
    TensorOut out);
24 changes: 24 additions & 0 deletions native/extorch/src/csrc/manipulation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -553,3 +553,27 @@ std::shared_ptr<CrossTensor> scatter(

return std::make_shared<CrossTensor>(std::move(out_tensor));
}

std::shared_ptr<CrossTensor> diagonal_scatter(
const std::shared_ptr<CrossTensor> &input,
const std::shared_ptr<CrossTensor> &src,
int64_t offset,
int64_t dim1,
int64_t dim2,
TensorOut out) {

CrossTensor out_tensor;
CrossTensor in_tensor = *input.get();
CrossTensor src_tensor = *src.get();

if(out.used) {
out_tensor = *out.tensor.get();
out_tensor = torch::diagonal_scatter_out(
out_tensor, in_tensor, src_tensor, offset, dim1, dim2);
} else {
out_tensor = torch::diagonal_scatter(
in_tensor, src_tensor, offset, dim1, dim2);
}

return std::make_shared<CrossTensor>(std::move(out_tensor));
}
1 change: 1 addition & 0 deletions native/extorch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ rustler::init!(
vstack,
select,
scatter,
diagonal_scatter,

// Tensor comparing operations
allclose,
Expand Down
10 changes: 10 additions & 0 deletions native/extorch/src/native/tensor/ops.rs.in
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,13 @@ fn scatter(
out: TensorOut,
inplace: bool,
) -> Result<SharedPtr<CrossTensor>>;

/// Embeds the values of the `src` tensor into `input` along the diagonal elements of `input`.
///
/// `offset` selects the diagonal (0 = main, > 0 above, < 0 below), while
/// `dim1` and `dim2` are the dimensions the diagonal is taken with respect to.
/// When `out` is marked as used, the result is stored in the provided tensor.
fn diagonal_scatter(
    input: &SharedPtr<CrossTensor>,
    src: &SharedPtr<CrossTensor>,
    offset: i64,
    dim1: i64,
    dim2: i64,
    out: TensorOut
) -> Result<SharedPtr<CrossTensor>>;
11 changes: 11 additions & 0 deletions native/extorch/src/nifs/tensor_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,14 @@ nif_impl!(
out: TensorOut,
inplace: bool
);

// NIF wrapper for `diagonal_scatter`: exposes the native implementation to
// Elixir (registered in `rustler::init!` in lib.rs) with the argument order
// (input, src, offset, dim1, dim2, out).
nif_impl!(
    diagonal_scatter,
    TensorStruct<'a>,
    input: TensorStruct<'a>,
    src: TensorStruct<'a>,
    offset: i64,
    dim1: i64,
    dim2: i64,
    out: TensorOut
);
70 changes: 70 additions & 0 deletions test/tensor/manipulation_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ defmodule ExTorchTest.Tensor.ManipulationTest do
input = ExTorch.zeros({3, 5}, dtype: src.dtype)

out = ExTorch.empty_like(input)

expected =
ExTorch.tensor([
[1.0000, 2.0000, 3.0000, 0.0000, 0.0000],
Expand Down Expand Up @@ -447,4 +448,73 @@ defmodule ExTorchTest.Tensor.ManipulationTest do
ExTorch.scatter(input, 0, index, src, nil, true)
assert ExTorch.allclose(input, expected)
end

test "diagonal_scatter/2" do
input = ExTorch.zeros({3, 3})
expected = ExTorch.eye(3)
out = ExTorch.diagonal_scatter(input, ExTorch.ones(3))
assert ExTorch.allclose(out, expected)
end

test "diagonal_scatter/3" do
input = ExTorch.zeros({3, 3})
expected = ExTorch.eye(3)
out = ExTorch.diagonal_scatter(input, ExTorch.ones(3), 0)
assert ExTorch.allclose(out, expected)
end

test "diagonal_scatter/3 with kwargs" do
input = ExTorch.zeros({3, 3})

expected =
ExTorch.tensor([
[0.0000, 1.0000, 0.0000],
[0.0000, 0.0000, 1.0000],
[0.0000, 0.0000, 0.0000]
])

out = ExTorch.diagonal_scatter(input, ExTorch.ones(2), offset: 1)
assert ExTorch.allclose(out, expected)
end

test "diagonal_scatter/4" do
input = ExTorch.zeros({3, 3, 3})
base = ExTorch.eye(3) |> ExTorch.unsqueeze(0)
expected = ExTorch.cat([base, base, base], 0)

out = ExTorch.diagonal_scatter(input, ExTorch.ones({3, 3}), 0, 2)
assert ExTorch.allclose(out, expected)
end

test "diagonal_scatter/4 with kwargs" do
input = ExTorch.zeros({3, 3, 3})

advanced_index = [
[0, 0, 0, 1, 1, 1, 2, 2, 2],
[0, 1, 2, 0, 1, 2, 0, 1, 2],
[0, 0, 0, 1, 1, 1, 2, 2, 2]
]

expected = ExTorch.index_put(input, advanced_index, 1.0)

out = ExTorch.diagonal_scatter(input, ExTorch.ones({3, 3}), 0, dim2: -1)
assert ExTorch.allclose(out, expected)
end

test "diagonal_scatter/5" do
input = ExTorch.zeros({3, 3, 3})
expected = ExTorch.index_put(input, [[0, 1, 2], [0, 1, 2]], 1.0)

out = ExTorch.diagonal_scatter(input, ExTorch.ones({3, 3}), 0, 0, 1)
assert ExTorch.allclose(out, expected)
end

test "diagonal_scatter/6" do
input = ExTorch.zeros({3, 3, 3})
expected = ExTorch.index_put(input, [[0, 1, 2], [0, 1, 2]], 1.0)

out = ExTorch.empty_like(input)
ExTorch.diagonal_scatter(input, ExTorch.ones({3, 3}), 0, 0, 1, out)
assert ExTorch.allclose(out, expected)
end
end

0 comments on commit e488310

Please sign in to comment.