Skip to content

Commit

Permalink
Merge pull request #16 from andfoy/add_numel
Browse files Browse the repository at this point in the history
Add `numel` to ExTorch.Tensor
  • Loading branch information
andfoy authored Aug 30, 2023
2 parents 752056e + 6561837 commit 7f2e4a0
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 6 deletions.
14 changes: 14 additions & 0 deletions lib/extorch/native/tensor/info.ex
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,19 @@ defmodule ExTorch.Native.Tensor.Info do
"""
@spec to_list(ExTorch.Tensor.t()) :: list()
defbinding(to_list(tensor))

@doc """
Returns the total number of elements in the input tensor.
## Arguments
- `tensor` (`ExTorch.Tensor`): Input tensor.
## Examples
iex> x = ExTorch.empty({3, 4, 5})
iex> ExTorch.Tensor.numel(x)
60
"""
@spec numel(ExTorch.Tensor.t()) :: integer()
defbinding(numel(tensor))
end
end
3 changes: 2 additions & 1 deletion native/extorch/include/info.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ rust::String dtype(const std::shared_ptr<CrossTensor> &tensor);
Device device(const std::shared_ptr<CrossTensor> &tensor);
rust::String repr(const std::shared_ptr<CrossTensor> &tensor, const PrintOptions opts);
ScalarList to_list(const std::shared_ptr<CrossTensor> &tensor);
bool requires_grad(const std::shared_ptr<CrossTensor> &tensor);
bool requires_grad(const std::shared_ptr<CrossTensor> &tensor);
int64_t numel(const std::shared_ptr<CrossTensor> &tensor);
7 changes: 6 additions & 1 deletion native/extorch/src/csrc/info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,9 @@ ScalarList to_list(const std::shared_ptr<CrossTensor> &tensor) {
bool requires_grad(const std::shared_ptr<CrossTensor> &tensor) {
CrossTensor cross_tensor = *tensor.get();
return cross_tensor.requires_grad();
}
}

int64_t numel(const std::shared_ptr<CrossTensor> &tensor) {
CrossTensor cross_tensor = *tensor.get();
return cross_tensor.numel();
}
15 changes: 11 additions & 4 deletions native/extorch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,19 @@ rustler::init!(
"Elixir.ExTorch.Native",
[
add,

// Tensor information
repr,
size,
device,
dtype,
requires_grad,
numel,
to_list,
real,
imag,

// Tensor creation
empty,
zeros,
ones,
Expand All @@ -61,14 +69,13 @@ rustler::init!(
tensor,
complex,
polar,
to_list,
view_as_complex,

// Tensor manipulation
unsqueeze,
reshape,
index,
index_put,
real,
imag,
view_as_complex
],
load = load
);
3 changes: 3 additions & 0 deletions native/extorch/src/native/tensor/info.rs.in
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ fn repr(tensor: &SharedPtr<CrossTensor>, opts: PrintOptions) -> Result<String>;

/// Convert a tensor into a list
fn to_list(tensor: &SharedPtr<CrossTensor>) -> Result<ScalarList>;

/// Return the total number of elements of a tensor.
fn numel(tensor: &SharedPtr<CrossTensor>) -> Result<i64>;
1 change: 1 addition & 0 deletions native/extorch/src/nifs/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ nif_impl!(device, torch::Device, tensor: TensorStruct<'a>);
nif_impl!(dtype, AtomString, tensor: TensorStruct<'a>);
nif_impl!(to_list, torch::ScalarList, tensor: TensorStruct<'a>);
nif_impl!(requires_grad, bool, tensor: TensorStruct<'a>);
nif_impl!(numel, i64, tensor: TensorStruct<'a>);
// nif_impl!(repr, String, tensor => Tensor);
5 changes: 5 additions & 0 deletions test/tensor/info_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,9 @@ defmodule ExTorchTest.Tensor.InfoTest do
tensor = ExTorch.empty({2})
assert !ExTorch.Tensor.requires_grad(tensor)
end

test "numel/1" do
tensor = ExTorch.empty({3, 4, 5})
assert ExTorch.Tensor.numel(tensor) == 3 * 4 * 5
end
end

0 comments on commit 7f2e4a0

Please sign in to comment.