Skip to content

Commit

Permalink
now the only TODOs in the ragged.array class are unimplemented linear…
Browse files Browse the repository at this point in the history
… algebra functions
  • Loading branch information
jpivarski committed Jan 15, 2024
1 parent 3ff558f commit 76a0591
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions src/ragged/_spec_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations

import copy as copy_lib
import enum
import numbers
from collections.abc import Iterator
Expand Down Expand Up @@ -232,8 +233,8 @@ def __init__(
else:
self._device = "cuda"

if copy is not None:
raise NotImplementedError("TODO 1") # noqa: EM101
if copy and isinstance(self._impl, ak.Array):
self._impl = copy_lib.deepcopy(self._impl)

def __str__(self) -> str:
"""
Expand Down Expand Up @@ -867,7 +868,8 @@ def __setitem__(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__setitem__.html
"""

raise NotImplementedError("TODO 31") # noqa: EM101
msg = "ragged.array is an immutable type; its values cannot be assigned to"
raise TypeError(msg)

def __sub__(self, other: int | float | array, /) -> array:
"""
Expand Down Expand Up @@ -930,26 +932,42 @@ def to_device(self, device: Device, /, *, stream: None | int | Any = None) -> ar
main memory; if `"cuda"`, the array is backed by CuPy and
resides in CUDA global memory.
stream: CuPy Stream object (https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Stream.html)
for `device="cuda"`.
for `device="cuda"`. Ignored if output `device` is `"cpu"`. If
this argument is an integer, it is interpreted as the pointer
address of a `cudaStream_t` object.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.to_device.html
"""

if isinstance(stream, numbers.Integral):
cp = _import.cupy()
stream = cp.cuda.ExternalStream(stream)

if stream is not None:
t = type(stream)
if not t.__module__.startswith("cupy.") or t.__name__ != "Stream":
msg = f"stream object must be a cupy.cuda.Stream, not {stream!r}"
raise TypeError(msg)

if isinstance(self._impl, ak.Array):
if device != ak.backend(self._impl):
if stream is not None:
raise NotImplementedError("TODO 124") # noqa: EM101
impl = ak.to_backend(self._impl, device)
with stream:
impl = ak.to_backend(self._impl, device)
else:
impl = ak.to_backend(self._impl, device)
else:
impl = self._impl

elif isinstance(self._impl, np.ndarray):
# self._impl is a NumPy 0-dimensional array
if device == "cuda":
if stream is not None:
raise NotImplementedError("TODO 125") # noqa: EM101
cp = _import.cupy()
impl = cp.array(self._impl)
if stream is not None:
with stream:
impl = cp.array(self._impl)
else:
impl = cp.array(self._impl)
else:
impl = self._impl

Expand Down

0 comments on commit 76a0591

Please sign in to comment.