Skip to content

Commit

Permalink
Updated the make_array() function.
Browse files Browse the repository at this point in the history
Before the `order` argument was a `Literal` but this caused more truble now it is a string.
  • Loading branch information
philip-paul-mueller committed Sep 25, 2024
1 parent a546932 commit 090c3a2
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def _test_impl_reshaping(
src_shape: Sequence[int], dst_shape: Sequence[int], order: str = "C"
) -> None:
"""Performs a reshaping from `src_shape` to `dst_shape`."""
a = testutil.make_array(src_shape)
a = np.array(a, order=order) # type: ignore[call-overload] # MyPy wants a literal as order.
a = testutil.make_array(src_shape, order=order)

def testee(a: np.ndarray) -> jax.Array:
return jnp.reshape(a, dst_shape)
Expand Down
6 changes: 3 additions & 3 deletions tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing import Any

import numpy as np

Expand All @@ -23,7 +23,7 @@
def make_array(
shape: Sequence[int] | int,
dtype: type = np.float64,
order: Literal[None, "K", "A", "C", "F"] = "C",
order: str = "C",
low: Any = None,
high: Any = None,
) -> np.ndarray:
Expand Down Expand Up @@ -69,7 +69,7 @@ def make_array(
res = low + (high - low) * res
assert (low is None) == (high is None)

return np.array(res, order=order, dtype=dtype)
return np.array(res, order=order, dtype=dtype) # type: ignore[call-overload] # Because we use `str` as `order`.


def set_active_primitive_translators_to(
Expand Down

0 comments on commit 090c3a2

Please sign in to comment.