Skip to content

Commit

Permalink
fix: ak.from_numpy should fail on zero-dimensional arrays. (#3161)
Browse files Browse the repository at this point in the history
* Fixing policy issue 1057

At this point it's a matter of consistency. Now *both* `ak.Array()` and `ak.from_numpy()` will throw a TypeError when passed either a python scalar, a numpy numeric value like `np.float64(5.2)`, **or** a numpy zero-dimensional-array like `np.array(5.2)`. But in the latter case there is an option to allow this, which is necessary for lots of internal functions, by providing a value for `primitive_policy` other than "error."

Prior to this patch the first test, ak.Array(), was already passing without modifications.
But the second test was not -- ak.from_numpy().
Fix is in _layout.py. No policy options are passed to _layout.from_arraylib().

* Adding/passing a primitive_policy kwarg

From from_numpy, from_cupy, from_jax, and from_dlpack
To from_arraylib

---------

Co-authored-by: Ianna Osborne <[email protected]>
  • Loading branch information
tcawlfield and ianna authored Jun 24, 2024
1 parent 03f6169 commit db6cece
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 10 deletions.
12 changes: 11 additions & 1 deletion src/awkward/_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,12 @@ def maybe_highlevel_to_lowlevel(obj):
return obj


def from_arraylib(array, regulararray, recordarray):
def from_arraylib(
array,
regulararray,
recordarray,
primitive_policy: Literal["error", "promote", "pass-through"] = "promote",
):
from awkward.contents import (
ByteMaskedArray,
ListArray,
Expand Down Expand Up @@ -341,6 +346,11 @@ def attach(x):
if array.dtype == np.dtype("O"):
raise TypeError("Awkward Array does not support arrays with object dtypes.")

if primitive_policy == "error" and array.ndim == 0:
raise TypeError(
f"Encountered a scalar ({type(array).__name__}), but scalar conversion/promotion is disabled"
)

if isinstance(array, numpy.ma.MaskedArray):
mask = numpy.ma.getmask(array)
array = numpy.ma.getdata(array)
Expand Down
12 changes: 10 additions & 2 deletions src/awkward/operations/ak_from_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@


@high_level_function()
def from_cupy(array, *, regulararray=False, highlevel=True, behavior=None, attrs=None):
def from_cupy(
array,
*,
regulararray=False,
highlevel=True,
behavior=None,
primitive_policy="error",
attrs=None,
):
"""
Args:
array (cp.ndarray): The CuPy array to convert into an Awkward Array.
Expand All @@ -36,7 +44,7 @@ def from_cupy(array, *, regulararray=False, highlevel=True, behavior=None, attrs
See also #ak.to_cupy, #ak.from_numpy and #ak.from_jax.
"""
return wrap_layout(
from_arraylib(array, regulararray, False),
from_arraylib(array, regulararray, False, primitive_policy=primitive_policy),
highlevel=highlevel,
behavior=behavior,
)
3 changes: 2 additions & 1 deletion src/awkward/operations/ak_from_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def from_dlpack(
regulararray=False,
highlevel=True,
behavior=None,
primitive_policy="error",
attrs=None,
):
"""
Expand Down Expand Up @@ -77,7 +78,7 @@ def from_dlpack(

array = nplike.from_dlpack(array)
return wrap_layout(
from_arraylib(array, regulararray, False),
from_arraylib(array, regulararray, False, primitive_policy=primitive_policy),
highlevel=highlevel,
behavior=behavior,
)
12 changes: 10 additions & 2 deletions src/awkward/operations/ak_from_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@


@high_level_function()
def from_jax(array, *, regulararray=False, highlevel=True, behavior=None, attrs=None):
def from_jax(
array,
*,
regulararray=False,
highlevel=True,
behavior=None,
attrs=None,
primitive_policy="error",
):
"""
Args:
array (jax.numpy.DeviceArray): The JAX DeviceArray to convert into an Awkward Array.
Expand Down Expand Up @@ -38,7 +46,7 @@ def from_jax(array, *, regulararray=False, highlevel=True, behavior=None, attrs=
"""
jax.assert_registered()
return wrap_layout(
from_arraylib(array, regulararray, False),
from_arraylib(array, regulararray, False, primitive_policy=primitive_policy),
highlevel=highlevel,
behavior=behavior,
)
5 changes: 4 additions & 1 deletion src/awkward/operations/ak_from_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def from_numpy(
recordarray=True,
highlevel=True,
behavior=None,
primitive_policy="error",
attrs=None,
):
"""
Expand Down Expand Up @@ -52,7 +53,9 @@ def from_numpy(
See also #ak.to_numpy and #ak.from_cupy.
"""
return wrap_layout(
from_arraylib(array, regulararray, recordarray),
from_arraylib(
array, regulararray, recordarray, primitive_policy=primitive_policy
),
highlevel=highlevel,
behavior=behavior,
)
16 changes: 13 additions & 3 deletions src/awkward/operations/ak_to_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,20 +179,27 @@ def _impl(
regulararray=regulararray,
recordarray=True,
highlevel=False,
primitive_policy=primitive_policy,
)
return _handle_array_like(
obj, promoted_layout, primitive_policy=primitive_policy
)
elif Cupy.is_own_array(obj):
promoted_layout = ak.operations.from_cupy(
obj, regulararray=regulararray, highlevel=False
obj,
regulararray=regulararray,
highlevel=False,
primitive_policy=primitive_policy,
)
return _handle_array_like(
obj, promoted_layout, primitive_policy=primitive_policy
)
elif Jax.is_own_array(obj):
promoted_layout = ak.operations.from_jax(
obj, regulararray=regulararray, highlevel=False
obj,
regulararray=regulararray,
highlevel=False,
primitive_policy=primitive_policy,
)
return _handle_array_like(
obj, promoted_layout, primitive_policy=primitive_policy
Expand All @@ -215,14 +222,17 @@ def _impl(
elif ak._util.in_module(obj, "pyarrow"):
return ak.operations.from_arrow(obj, highlevel=False)
elif hasattr(obj, "__dlpack__") and hasattr(obj, "__dlpack_device__"):
return ak.operations.from_dlpack(obj, highlevel=False)
return ak.operations.from_dlpack(
obj, highlevel=False, primitive_policy=primitive_policy
)
# Typed scalars
elif isinstance(obj, np.generic):
promoted_layout = ak.operations.from_numpy(
numpy.asarray(obj),
regulararray=regulararray,
recordarray=True,
highlevel=False,
primitive_policy=primitive_policy,
)
return _handle_array_like(
obj, promoted_layout, primitive_policy=primitive_policy
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import numpy as np
import pytest

import awkward as ak


def test_akarray_from_zero_dim_nparray():
np_scalar = np.array(2.7) # A kind of scalar in numpy.
assert np_scalar.ndim == 0 and np_scalar.shape == ()
with pytest.raises(TypeError):
# Conversion to ak.Array ought to throw here:
b = ak.Array(np_scalar) # (bugged) value: <Array [2.7] type='1 * int64'>
# Now we're failing. Here's why.
c = ak.to_numpy(b) # value: array([2.7])
assert np_scalar.shape == c.shape # this fails

with pytest.raises(TypeError):
b = ak.from_numpy(np_scalar)
c = ak.to_numpy(b)
assert np_scalar.shape == c.shape

0 comments on commit db6cece

Please sign in to comment.