Skip to content

Commit

Permalink
Merge pull request #618 from mrava87/bug-restrictioncupy
Browse files Browse the repository at this point in the history
bug: fix problem in Restriction when passing iava as cupy array
  • Loading branch information
mrava87 authored Oct 26, 2024
2 parents bcd4da2 + 71dc191 commit 6d6887a
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions pylops/basicoperators/restriction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
get_array_module,
get_normalize_axis_index,
inplace_set,
to_cupy_conditional,
to_numpy,
)
from pylops.utils.typing import DTypeLike, InputDimsLike, IntNDArray, NDArray

Expand Down Expand Up @@ -146,7 +146,7 @@ def __init__(
# explicitly create a list of indices in the n-dimensional
# model space which will be used in _rmatvec to place the input
if ncp != np:
self.iavamask = _compute_iavamask(self.dims, axis, iava, ncp)
self.iavamask = _compute_iavamask(self.dims, axis, to_numpy(iava), ncp)
self.inplace = inplace
self.axis = axis
self.iavareshape = iavareshape
Expand All @@ -173,7 +173,6 @@ def _rmatvec(self, x: NDArray) -> NDArray:
)
else:
if not hasattr(self, "iavamask"):
self.iava = to_cupy_conditional(x, self.iava)
self.iavamask = _compute_iavamask(self.dims, self.axis, self.iava, ncp)
y = ncp.zeros(int(self.shape[-1]), dtype=self.dtype)
y = inplace_set(x.ravel(), y, self.iavamask)
Expand Down

0 comments on commit 6d6887a

Please sign in to comment.