Workaround for dpnp.linalg.qr() to run on CUDA (#2265)
This PR adds a workaround (waiting on the host task after
calling `geqrf`) to avoid a race condition caused by an issue in oneMath:
uxlfoundation/oneMath#626

It also updates the tests by removing outdated skips and adding
`test_qr_large` to `TestQr`.
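
For context, the sketch below restates the workaround in isolation. It is a minimal sketch that only reuses calls appearing in the diff (`dpnp.is_cuda_backend`, `ht_ev.wait()`, `_manager.add_event_pair`); the wrapper function name is hypothetical, and the actual change is applied inline in `_batched_qr` and `dpnp_qr` in `dpnp/linalg/dpnp_utils_linalg.py`.

```python
# Minimal sketch of the workaround pattern; the helper name
# `_finalize_geqrf_event` is hypothetical and exists only for illustration.
import dpnp


def _finalize_geqrf_event(ht_ev, geqrf_ev, a_sycl_queue, _manager):
    if dpnp.is_cuda_backend(a_sycl_queue):
        # On CUDA, block on the host task right away to avoid the
        # race condition tracked in uxlfoundation/oneMath#626.
        ht_ev.wait()
    else:
        # Elsewhere, keep the usual non-blocking path and let the
        # queue manager track the (host task, compute event) pair.
        _manager.add_event_pair(ht_ev, geqrf_ev)
```

Blocking only on the CUDA backend keeps the non-blocking event-tracking path unchanged everywhere else.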
vlad-perevezentsev authored Jan 20, 2025
1 parent 9ad1bb5 commit 807179a
Showing 3 changed files with 52 additions and 27 deletions.
18 changes: 16 additions & 2 deletions dpnp/linalg/dpnp_utils_linalg.py
@@ -397,7 +397,14 @@ def _batched_qr(a, mode="reduced"):
batch_size,
depends=[copy_ev],
)
_manager.add_event_pair(ht_ev, geqrf_ev)

# w/a to avoid a race condition on CUDA during multiple runs
# TODO: Remove it once the oneMath issue is resolved
# https://github.com/uxlfoundation/oneMath/issues/626
if dpnp.is_cuda_backend(a_sycl_queue):
ht_ev.wait()
else:
_manager.add_event_pair(ht_ev, geqrf_ev)

if mode in ["r", "raw"]:
if mode == "r":
@@ -2468,7 +2475,14 @@ def dpnp_qr(a, mode="reduced"):
ht_ev, geqrf_ev = li._geqrf(
a_sycl_queue, a_t.get_array(), tau_h.get_array(), depends=[copy_ev]
)
_manager.add_event_pair(ht_ev, geqrf_ev)

# w/a to avoid a race condition on CUDA during multiple runs
# TODO: Remove it once the oneMath issue is resolved
# https://github.com/uxlfoundation/oneMath/issues/626
if dpnp.is_cuda_backend(a_sycl_queue):
ht_ev.wait()
else:
_manager.add_event_pair(ht_ev, geqrf_ev)

if mode in ["r", "raw"]:
if mode == "r":
54 changes: 36 additions & 18 deletions dpnp/tests/test_linalg.py
@@ -2380,12 +2380,6 @@ class TestQr:
)
@pytest.mark.parametrize("mode", ["r", "raw", "complete", "reduced"])
def test_qr(self, dtype, shape, mode):
if (
is_cuda_device()
and mode in ["complete", "reduced"]
and shape in [(16, 16), (2, 2, 4)]
):
pytest.skip("SAT-7589")
a = generate_random_numpy_array(shape, dtype, seed_value=81)
ia = dpnp.array(a)

@@ -2398,24 +2392,48 @@ def test_qr(self, dtype, shape, mode):

# check decomposition
if mode in ("complete", "reduced"):
if a.ndim == 2:
assert_almost_equal(
dpnp.dot(dpnp_q, dpnp_r),
a,
decimal=5,
)
else: # a.ndim > 2
assert_almost_equal(
dpnp.matmul(dpnp_q, dpnp_r),
a,
decimal=5,
)
assert_almost_equal(
dpnp.matmul(dpnp_q, dpnp_r),
a,
decimal=5,
)
else: # mode=="raw"
assert_dtype_allclose(dpnp_q, np_q)

if mode in ("raw", "r"):
assert_dtype_allclose(dpnp_r, np_r)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
@pytest.mark.parametrize(
"shape",
[(32, 32), (8, 16, 16)],
ids=[
"(32, 32)",
"(8, 16, 16)",
],
)
@pytest.mark.parametrize("mode", ["r", "raw", "complete", "reduced"])
def test_qr_large(self, dtype, shape, mode):
a = generate_random_numpy_array(shape, dtype, seed_value=81)
ia = dpnp.array(a)
if mode == "r":
np_r = numpy.linalg.qr(a, mode)
dpnp_r = dpnp.linalg.qr(ia, mode)
else:
np_q, np_r = numpy.linalg.qr(a, mode)
dpnp_q, dpnp_r = dpnp.linalg.qr(ia, mode)
# check decomposition
if mode in ("complete", "reduced"):
assert_almost_equal(
dpnp.matmul(dpnp_q, dpnp_r),
a,
decimal=5,
)
else: # mode=="raw"
assert_allclose(np_q, dpnp_q, atol=1e-4)
if mode in ("raw", "r"):
assert_allclose(np_r, dpnp_r, atol=1e-4)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
@pytest.mark.parametrize(
"shape",
@@ -163,14 +163,7 @@ def test_decomposition(self, dtype):
class TestQRDecomposition(unittest.TestCase):

@testing.for_dtypes("fdFD")
# skip cases with 'complete' and 'reduce' modes on CUDA (SAT-7611)
def check_mode(self, array, mode, dtype):
if (
is_cuda_device()
and array.size > 0
and mode in ["complete", "reduced"]
):
return
a_cpu = numpy.asarray(array, dtype=dtype)
a_gpu = cupy.asarray(array, dtype=dtype)
result_gpu = cupy.linalg.qr(a_gpu, mode=mode)
