From 5c0b3de6164b9fd7e28ddd9ffa2ec7dbbc234f08 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 30 Apr 2024 22:02:32 +0000 Subject: [PATCH] Refactor array_api namespace, relying more directly on jax.numpy --- jax/_src/dtypes.py | 7 +- jax/experimental/array_api/__init__.py | 164 ++++---- jax/experimental/array_api/_array_methods.py | 2 +- jax/experimental/array_api/_constants.py | 21 - .../array_api/_creation_functions.py | 39 +- .../array_api/_data_type_functions.py | 180 +------- jax/experimental/array_api/_dtypes.py | 29 -- .../array_api/_elementwise_functions.py | 390 +----------------- jax/experimental/array_api/_fft_functions.py | 51 +-- .../array_api/_indexing_functions.py | 18 - .../array_api/_linear_algebra_functions.py | 128 +----- .../array_api/_manipulation_functions.py | 71 +--- .../array_api/_searching_functions.py | 48 --- jax/experimental/array_api/_set_functions.py | 35 -- .../array_api/_sorting_functions.py | 28 -- .../array_api/_statistical_functions.py | 37 +- .../array_api/_utility_functions.py | 11 +- jax/experimental/array_api/_version.py | 2 +- jax/experimental/array_api/fft.py | 23 +- jax/experimental/array_api/linalg.py | 15 +- jax/experimental/array_api/skips.txt | 6 + pyproject.toml | 5 + tests/array_api_test.py | 6 +- 23 files changed, 171 insertions(+), 1145 deletions(-) delete mode 100644 jax/experimental/array_api/_constants.py delete mode 100644 jax/experimental/array_api/_dtypes.py delete mode 100644 jax/experimental/array_api/_indexing_functions.py delete mode 100644 jax/experimental/array_api/_searching_functions.py delete mode 100644 jax/experimental/array_api/_set_functions.py delete mode 100644 jax/experimental/array_api/_sorting_functions.py diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index c4ea985b553d..8f065122697a 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -467,8 +467,10 @@ def isdtype(dtype: DTypeLike, kind: str | DType | tuple[str | DType, ...]) -> bo options.update(_dtype_kinds[kind]) elif isinstance(kind, np.dtype): options.add(kind) + # Check for _ScalarMeta without referencing the class directly + elif issubclass(kind.__class__, type) and isinstance(getattr(kind, 'dtype'), np.dtype): + options.add(kind.dtype) else: - # TODO(jakevdp): should we handle scalar types or ScalarMeta here? raise TypeError(f"Expected kind to be a dtype, string, or tuple; got {kind=}") return the_dtype in options @@ -652,6 +654,9 @@ def check_valid_dtype(dtype: DType) -> None: raise TypeError(f"Dtype {dtype} is not a valid JAX array " "type. Only arrays of numeric types are supported by JAX.") +def is_valid_dtype(dtype: DType) -> bool: + return dtype in _jax_dtype_set + def dtype(x: Any, *, canonicalize: bool = False) -> DType: """Return the dtype object for a value or type, optionally canonicalized based on X64 mode.""" if x is None: diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index d8c42bc56b3e..7912c4a276a0 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -21,7 +21,7 @@ >>> from jax.experimental import array_api as xp >>> xp.__array_api_version__ - '2022.12' + '2023.12' >>> arr = xp.arange(1000) @@ -38,64 +38,17 @@ from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__ -from jax.experimental.array_api import ( - fft as fft, - linalg as linalg, -) +from jax.experimental.array_api import fft as fft +from jax.experimental.array_api import linalg as linalg -from jax.experimental.array_api._constants import ( +from jax.numpy import ( e as e, inf as inf, nan as nan, newaxis as newaxis, pi as pi, -) - -from jax.experimental.array_api._creation_functions import ( - arange as arange, - asarray as asarray, - empty as empty, - empty_like as empty_like, - eye as eye, - from_dlpack as from_dlpack, - full as full, - full_like as full_like, - linspace as linspace, - meshgrid as meshgrid, - ones as ones, - ones_like as ones_like, tril as tril, triu as triu, - zeros as zeros, - zeros_like as zeros_like, -) - -from jax.experimental.array_api._data_type_functions import ( - astype as astype, - can_cast as can_cast, - finfo as finfo, - iinfo as iinfo, - isdtype as isdtype, - result_type as result_type, -) - -from jax.experimental.array_api._dtypes import ( - bool as bool, - int8 as int8, - int16 as int16, - int32 as int32, - int64 as int64, - uint8 as uint8, - uint16 as uint16, - uint32 as uint32, - uint64 as uint64, - float32 as float32, - float64 as float64, - complex64 as complex64, - complex128 as complex128, -) - -from jax.experimental.array_api._elementwise_functions import ( abs as abs, acos as acos, acosh as acosh, @@ -111,8 +64,6 @@ bitwise_or as bitwise_or, bitwise_right_shift as bitwise_right_shift, bitwise_xor as bitwise_xor, - ceil as ceil, - clip as clip, conj as conj, copysign as copysign, cos as cos, @@ -121,11 +72,9 @@ equal as equal, exp as exp, expm1 as expm1, - floor as floor, floor_divide as floor_divide, greater as greater, greater_equal as greater_equal, - hypot as hypot, imag as imag, isfinite as isfinite, isinf as isinf, @@ -137,10 +86,7 @@ log1p as log1p, log2 as log2, logaddexp as logaddexp, - logical_and as logical_and, logical_not as logical_not, - logical_or as logical_or, - logical_xor as logical_xor, maximum as maximum, minimum as minimum, multiply as multiply, @@ -151,7 +97,6 @@ real as real, remainder as remainder, round as round, - sign as sign, signbit as signbit, sin as sin, sinh as sinh, @@ -159,15 +104,8 @@ square as square, subtract as subtract, tan as tan, - tanh as tanh, - trunc as trunc, -) - -from jax.experimental.array_api._indexing_functions import ( take as take, -) - -from jax.experimental.array_api._manipulation_functions import ( + tanh as tanh, broadcast_arrays as broadcast_arrays, broadcast_to as broadcast_to, concat as concat, @@ -176,56 +114,98 @@ moveaxis as moveaxis, permute_dims as permute_dims, repeat as repeat, - reshape as reshape, roll as roll, squeeze as squeeze, stack as stack, tile as tile, unstack as unstack, -) - -from jax.experimental.array_api._searching_functions import ( argmax as argmax, argmin as argmin, - nonzero as nonzero, searchsorted as searchsorted, where as where, -) - -from jax.experimental.array_api._set_functions import ( unique_all as unique_all, unique_counts as unique_counts, unique_inverse as unique_inverse, unique_values as unique_values, -) - -from jax.experimental.array_api._sorting_functions import ( argsort as argsort, sort as sort, -) - -from jax.experimental.array_api._statistical_functions import ( cumulative_sum as cumulative_sum, max as max, mean as mean, min as min, - prod as prod, - std as std, - sum as sum, - var as var -) - -from jax.experimental.array_api._utility_functions import ( - __array_namespace_info__ as __array_namespace_info__, all as all, any as any, -) - -from jax.experimental.array_api._linear_algebra_functions import ( + from_dlpack as from_dlpack, + meshgrid as meshgrid, + empty as empty, + empty_like as empty_like, + full as full, + full_like as full_like, + ones as ones, + ones_like as ones_like, + zeros as zeros, + zeros_like as zeros_like, + can_cast as can_cast, + isdtype as isdtype, + result_type as result_type, + iinfo as iinfo, + sign as sign, + nonzero as nonzero, + prod as prod, + sum as sum, matmul as matmul, matrix_transpose as matrix_transpose, tensordot as tensordot, vecdot as vecdot, + bool as bool, + int8 as int8, + int16 as int16, + int32 as int32, + int64 as int64, + uint8 as uint8, + uint16 as uint16, + uint32 as uint32, + uint64 as uint64, + float32 as float32, + float64 as float64, + complex64 as complex64, + complex128 as complex128, +) + +from jax.experimental.array_api._manipulation_functions import ( + reshape as reshape, +) + +from jax.experimental.array_api._creation_functions import ( + arange as arange, + asarray as asarray, + eye as eye, + linspace as linspace, +) + +from jax.experimental.array_api._data_type_functions import ( + astype as astype, + finfo as finfo, +) + +from jax.experimental.array_api._elementwise_functions import ( + ceil as ceil, + clip as clip, + floor as floor, + hypot as hypot, + logical_and as logical_and, + logical_or as logical_or, + logical_xor as logical_xor, + trunc as trunc, +) + +from jax.experimental.array_api._statistical_functions import ( + std as std, + var as var, +) + +from jax.experimental.array_api._utility_functions import ( + __array_namespace_info__ as __array_namespace_info__, ) from jax.experimental.array_api import _array_methods diff --git a/jax/experimental/array_api/_array_methods.py b/jax/experimental/array_api/_array_methods.py index ca5bca356258..2b071db573a8 100644 --- a/jax/experimental/array_api/_array_methods.py +++ b/jax/experimental/array_api/_array_methods.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, Callable +from typing import Any import jax from jax._src.array import ArrayImpl diff --git a/jax/experimental/array_api/_constants.py b/jax/experimental/array_api/_constants.py deleted file mode 100644 index e6f0d542ae79..000000000000 --- a/jax/experimental/array_api/_constants.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np - -e = np.e -inf = np.inf -nan = np.nan -newaxis = np.newaxis -pi = np.pi diff --git a/jax/experimental/array_api/_creation_functions.py b/jax/experimental/array_api/_creation_functions.py index 2fd9be97ba27..99b8e3ed4465 100644 --- a/jax/experimental/array_api/_creation_functions.py +++ b/jax/experimental/array_api/_creation_functions.py @@ -16,53 +16,16 @@ import jax import jax.numpy as jnp -from jax._src.lib import xla_client as xc -from jax._src.sharding import Sharding +# TODO(micky774): Deprecate after adding device argument to jax.numpy functions def arange(start, /, stop=None, step=1, *, dtype=None, device=None): return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device) def asarray(obj, /, *, dtype=None, device=None, copy=None): return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device) -def empty(shape, *, dtype=None, device=None): - return jax.device_put(jnp.empty(shape, dtype=dtype), device=device) - -def empty_like(x, /, *, dtype=None, device=None): - return jax.device_put(jnp.empty_like(x, dtype=dtype), device=device) - def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None): return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device) -def from_dlpack(x, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None): - return jnp.from_dlpack(x, device=device, copy=copy) - -def full(shape, fill_value, *, dtype=None, device=None): - return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device) - -def full_like(x, /, fill_value, *, dtype=None, device=None): - return jax.device_put(jnp.full_like(x, fill_value=fill_value, dtype=dtype), device=device) - def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True): return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device) - -def meshgrid(*arrays, indexing='xy'): - return jnp.meshgrid(*arrays, indexing=indexing) - -def ones(shape, *, dtype=None, device=None): - return jax.device_put(jnp.ones(shape, dtype=dtype), device=device) - -def ones_like(x, /, *, dtype=None, device=None): - return jax.device_put(jnp.ones_like(x, dtype=dtype), device=device) - -def tril(x, /, *, k=0): - return jnp.tril(x, k=k) - -def triu(x, /, *, k=0): - return jnp.triu(x, k=k) - -def zeros(shape, *, dtype=None, device=None): - return jax.device_put(jnp.zeros(shape, dtype=dtype), device=device) - -def zeros_like(x, /, *, dtype=None, device=None): - return jax.device_put(jnp.zeros_like(x, dtype=dtype), device=device) diff --git a/jax/experimental/array_api/_data_type_functions.py b/jax/experimental/array_api/_data_type_functions.py index 770d264c1c07..988840c31381 100644 --- a/jax/experimental/array_api/_data_type_functions.py +++ b/jax/experimental/array_api/_data_type_functions.py @@ -15,120 +15,32 @@ from __future__ import annotations import builtins -import functools from typing import NamedTuple -import jax -import jax.numpy as jnp +import numpy as np +import jax.numpy as jnp from jax._src.lib import xla_client as xc from jax._src.sharding import Sharding from jax._src import dtypes as _dtypes -from jax.experimental.array_api._dtypes import ( - bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, - float32, float64, complex64, complex128 -) - -_valid_dtypes = { - bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, - float32, float64, complex64, complex128 -} - -_promotion_table = { - (bool, bool): bool, - (int8, int8): int8, - (int8, int16): int16, - (int8, int32): int32, - (int8, int64): int64, - (int8, uint8): int16, - (int8, uint16): int32, - (int8, uint32): int64, - (int16, int8): int16, - (int16, int16): int16, - (int16, int32): int32, - (int16, int64): int64, - (int16, uint8): int16, - (int16, uint16): int32, - (int16, uint32): int64, - (int32, int8): int32, - (int32, int16): int32, - (int32, int32): int32, - (int32, int64): int64, - (int32, uint8): int32, - (int32, uint16): int32, - (int32, uint32): int64, - (int64, int8): int64, - (int64, int16): int64, - (int64, int32): int64, - (int64, int64): int64, - (int64, uint8): int64, - (int64, uint16): int64, - (int64, uint32): int64, - (uint8, int8): int16, - (uint8, int16): int16, - (uint8, int32): int32, - (uint8, int64): int64, - (uint8, uint8): uint8, - (uint8, uint16): uint16, - (uint8, uint32): uint32, - (uint8, uint64): uint64, - (uint16, int8): int32, - (uint16, int16): int32, - (uint16, int32): int32, - (uint16, int64): int64, - (uint16, uint8): uint16, - (uint16, uint16): uint16, - (uint16, uint32): uint32, - (uint16, uint64): uint64, - (uint32, int8): int64, - (uint32, int16): int64, - (uint32, int32): int64, - (uint32, int64): int64, - (uint32, uint8): uint32, - (uint32, uint16): uint32, - (uint32, uint32): uint32, - (uint32, uint64): uint64, - (uint64, uint8): uint64, - (uint64, uint16): uint64, - (uint64, uint32): uint64, - (uint64, uint64): uint64, - (float32, float32): float32, - (float32, float64): float64, - (float32, complex64): complex64, - (float32, complex128): complex128, - (float64, float32): float64, - (float64, float64): float64, - (float64, complex64): complex128, - (float64, complex128): complex128, - (complex64, float32): complex64, - (complex64, float64): complex128, - (complex64, complex64): complex64, - (complex64, complex128): complex128, - (complex128, float32): complex128, - (complex128, float64): complex128, - (complex128, complex64): complex128, - (complex128, complex128): complex128, -} - - -def _is_valid_dtype(t): - try: - return t in _valid_dtypes - except TypeError: - return False - - -def _promote_types(t1, t2): - if not _is_valid_dtype(t1): - raise ValueError(f"{t1} is not a valid dtype") - if not _is_valid_dtype(t2): - raise ValueError(f"{t2} is not a valid dtype") - if result := _promotion_table.get((t1, t2), None): - return result - else: - raise ValueError("No promotion path for {t1} & {t2}") - +# TODO(micky774): Update jax.numpy dtypes to dtype *objects* +bool = np.dtype('bool') +int8 = np.dtype('int8') +int16 = np.dtype('int16') +int32 = np.dtype('int32') +int64 = np.dtype('int64') +uint8 = np.dtype('uint8') +uint16 = np.dtype('uint16') +uint32 = np.dtype('uint32') +uint64 = np.dtype('uint64') +float32 = np.dtype('float32') +float64 = np.dtype('float64') +complex64 = np.dtype('complex64') +complex128 = np.dtype('complex128') + + +# TODO(micky774): Remove when jax.numpy.astype is deprecation is completed def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None): src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x) if ( @@ -144,21 +56,6 @@ def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Shard return jnp.astype(x, dtype, copy=copy, device=device) -def can_cast(from_, to, /): - if isinstance(from_, jax.Array): - from_ = from_.dtype - if not _is_valid_dtype(from_): - raise ValueError(f"{from_} is not a valid dtype") - if not _is_valid_dtype(to): - raise ValueError(f"{to} is not a valid dtype") - try: - result = _promote_types(from_, to) - except ValueError: - return False - else: - return result == to - - class FInfo(NamedTuple): bits: int eps: float @@ -167,14 +64,8 @@ class FInfo(NamedTuple): smallest_normal: float dtype: jnp.dtype - -class IInfo(NamedTuple): - bits: int - max: int - min: int - dtype: jnp.dtype - - +# TODO(micky774): Update jax.numpy.finfo so that its attributes are python +# floats def finfo(type, /) -> FInfo: info = jnp.finfo(type) return FInfo( @@ -186,34 +77,7 @@ def finfo(type, /) -> FInfo: dtype=jnp.dtype(type) ) - -def iinfo(type, /) -> IInfo: - info = jnp.iinfo(type) - return IInfo(bits=info.bits, max=info.max, min=info.min, dtype=jnp.dtype(type)) - - -def isdtype(dtype, kind): - return jax.numpy.isdtype(dtype, kind) - - -def result_type(*arrays_and_dtypes): - dtypes = [] - for val in arrays_and_dtypes: - if isinstance(val, (builtins.bool, int, float, complex)): - val = jax.numpy.array(val) - if isinstance(val, jax.Array): - val = val.dtype - if _is_valid_dtype(val): - dtypes.append(val) - else: - raise ValueError(f"{val} is not a valid dtype") - if len(dtypes) == 0: - raise ValueError("result_type requires at least one argument") - if len(dtypes) == 1: - return dtypes[0] - return functools.reduce(_promote_types, dtypes) - - +# TODO(micky774): Update utility to only promote integral types def _promote_to_default_dtype(x): if x.dtype.kind == 'b': return x diff --git a/jax/experimental/array_api/_dtypes.py b/jax/experimental/array_api/_dtypes.py deleted file mode 100644 index 72229bfc28af..000000000000 --- a/jax/experimental/array_api/_dtypes.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np - -bool = np.dtype('bool') -int8 = np.dtype('int8') -int16 = np.dtype('int16') -int32 = np.dtype('int32') -int64 = np.dtype('int64') -uint8 = np.dtype('uint8') -uint16 = np.dtype('uint16') -uint32 = np.dtype('uint32') -uint64 = np.dtype('uint64') -float32 = np.dtype('float32') -float64 = np.dtype('float64') -complex64 = np.dtype('complex64') -complex128 = np.dtype('complex128') diff --git a/jax/experimental/array_api/_elementwise_functions.py b/jax/experimental/array_api/_elementwise_functions.py index f6f184dcf726..6162f8750fd5 100644 --- a/jax/experimental/array_api/_elementwise_functions.py +++ b/jax/experimental/array_api/_elementwise_functions.py @@ -13,125 +13,26 @@ # limitations under the License. import jax +from jax.numpy import isdtype from jax._src.dtypes import issubdtype -from jax.experimental.array_api._data_type_functions import ( - result_type as _result_type, - isdtype as _isdtype, -) - - -def _promote_dtypes(name, *args): - assert isinstance(name, str) - if not all(isinstance(arg, (bool, int, float, complex, jax.Array)) - for arg in args): - raise ValueError(f"{name}: inputs must be arrays; got types {[type(arg) for arg in args]}") - dtype = _result_type(*args) - return [jax.numpy.asarray(arg).astype(dtype) for arg in args] - - -def abs(x, /): - """Calculates the absolute value for each element x_i of the input array x.""" - x, = _promote_dtypes("abs", x) - return jax.numpy.abs(x) - - -def acos(x, /): - """Calculates an implementation-dependent approximation of the principal value of the inverse cosine for each element x_i of the input array x.""" - x, = _promote_dtypes("acos", x) - return jax.numpy.acos(x) - -def acosh(x, /): - """Calculates an implementation-dependent approximation to the inverse hyperbolic cosine for each element x_i of the input array x.""" - x, = _promote_dtypes("acos", x) - return jax.numpy.acosh(x) - - -def add(x1, x2, /): - """Calculates the sum for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("add", x1, x2) - return jax.numpy.add(x1, x2) - - -def asin(x, /): - """Calculates an implementation-dependent approximation of the principal value of the inverse sine for each element x_i of the input array x.""" - x, = _promote_dtypes("asin", x) - return jax.numpy.asin(x) - - -def asinh(x, /): - """Calculates an implementation-dependent approximation to the inverse hyperbolic sine for each element x_i in the input array x.""" - x, = _promote_dtypes("asinh", x) - return jax.numpy.asinh(x) - - -def atan(x, /): - """Calculates an implementation-dependent approximation of the principal value of the inverse tangent for each element x_i of the input array x.""" - x, = _promote_dtypes("atan", x) - return jax.numpy.atan(x) - - -def atan2(x1, x2, /): - """Calculates an implementation-dependent approximation of the inverse tangent of the quotient x1/x2, having domain [-infinity, +infinity] x [-infinity, +infinity] (where the x notation denotes the set of ordered pairs of elements (x1_i, x2_i)) and codomain [-π, +π], for each pair of elements (x1_i, x2_i) of the input arrays x1 and x2, respectively.""" - x1, x2 = _promote_dtypes("atan2", x1, x2) - return jax.numpy.arctan2(x1, x2) - - -def atanh(x, /): - """Calculates an implementation-dependent approximation to the inverse hyperbolic tangent for each element x_i of the input array x.""" - x, = _promote_dtypes("atanh", x) - return jax.numpy.atanh(x) - - -def bitwise_and(x1, x2, /): - """Computes the bitwise AND of the underlying binary representation of each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("bitwise_and", x1, x2) - return jax.numpy.bitwise_and(x1, x2) - - -def bitwise_left_shift(x1, x2, /): - """Shifts the bits of each element x1_i of the input array x1 to the left by appending x2_i (i.e., the respective element in the input array x2) zeros to the right of x1_i.""" - x1, x2 = _promote_dtypes("bitwise_left_shift", x1, x2) - return jax.numpy.bitwise_left_shift(x1, x2) - - -def bitwise_invert(x, /): - """Inverts (flips) each bit for each element x_i of the input array x.""" - x, = _promote_dtypes("bitwise_invert", x) - return jax.numpy.bitwise_invert(x) - - -def bitwise_or(x1, x2, /): - """Computes the bitwise OR of the underlying binary representation of each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("bitwise_or", x1, x2) - return jax.numpy.bitwise_or(x1, x2) - - -def bitwise_right_shift(x1, x2, /): - """Shifts the bits of each element x1_i of the input array x1 to the right according to the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("bitwise_right_shift", x1, x2) - return jax.numpy.bitwise_right_shift(x1, x2) - - -def bitwise_xor(x1, x2, /): - """Computes the bitwise XOR of the underlying binary representation of each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("bitwise_xor", x1, x2) - return jax.numpy.bitwise_xor(x1, x2) +from jax._src.numpy.util import promote_args +# TODO(micky774): Update jnp.ceil to preserve integral dtype def ceil(x, /): """Rounds each element x_i of the input array x to the smallest (i.e., closest to -infinity) integer-valued number that is not less than x_i.""" - x, = _promote_dtypes("ceil", x) - if _isdtype(x.dtype, "integral"): + x, = promote_args("ceil", x) + if isdtype(x.dtype, "integral"): return x return jax.numpy.ceil(x) +# TODO(micky774): Remove when jnp.clip deprecation is completed +# (began 2024-4-2) and default behavior is Array API 2023 compliant def clip(x, /, min=None, max=None): """Returns the complex conjugate for each element x_i of the input array x.""" - x, = _promote_dtypes("clip", x) + x, = promote_args("clip", x) - # TODO(micky774): Remove when jnp.clip deprecation is completed - # (began 2024-4-2) and default behavior is Array API 2023 compliant if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)): raise ValueError( "Clip received a complex value either through the input or the min/max " @@ -142,85 +43,21 @@ def clip(x, /, min=None, max=None): return jax.numpy.clip(x, min=min, max=max) -def conj(x, /): - """Returns the complex conjugate for each element x_i of the input array x.""" - x, = _promote_dtypes("conj", x) - return jax.numpy.conj(x) - - -def copysign(x1, x2, /): - """Composes a floating-point value with the magnitude of x1_i and the sign of x2_i for each element of the input array x1.""" - return jax.numpy.copysign(x1, x2) - - -def cos(x, /): - """Calculates an implementation-dependent approximation to the cosine for each element x_i of the input array x.""" - x, = _promote_dtypes("cos", x) - return jax.numpy.cos(x) - - -def cosh(x, /): - """Calculates an implementation-dependent approximation to the hyperbolic cosine for each element x_i in the input array x.""" - x, = _promote_dtypes("cosh", x) - return jax.numpy.cosh(x) - - -def divide(x1, x2, /): - """Calculates the division of each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("divide", x1, x2) - return jax.numpy.divide(x1, x2) - - -def equal(x1, x2, /): - """Computes the truth value of x1_i == x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("equal", x1, x2) - return jax.numpy.equal(x1, x2) - - -def exp(x, /): - """Calculates an implementation-dependent approximation to the exponential function for each element x_i of the input array x (e raised to the power of x_i, where e is the base of the natural logarithm).""" - x, = _promote_dtypes("exp", x) - return jax.numpy.exp(x) - - -def expm1(x, /): - """Calculates an implementation-dependent approximation to exp(x)-1 for each element x_i of the input array x.""" - x, = _promote_dtypes("expm1", x) - return jax.numpy.expm1(x) - - +# TODO(micky774): Update jnp.floor to preserve integral dtype def floor(x, /): """Rounds each element x_i of the input array x to the greatest (i.e., closest to +infinity) integer-valued number that is not greater than x_i.""" - x, = _promote_dtypes("floor", x) - if _isdtype(x.dtype, "integral"): + x, = promote_args("floor", x) + if isdtype(x.dtype, "integral"): return x return jax.numpy.floor(x) -def floor_divide(x1, x2, /): - """Rounds the result of dividing each element x1_i of the input array x1 by the respective element x2_i of the input array x2 to the greatest (i.e., closest to +infinity) integer-value number that is not greater than the division result.""" - x1, x2 = _promote_dtypes("floor_divide", x1, x2) - return jax.numpy.floor_divide(x1, x2) - - -def greater(x1, x2, /): - """Computes the truth value of x1_i > x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("greater", x1, x2) - return jax.numpy.greater(x1, x2) - - -def greater_equal(x1, x2, /): - """Computes the truth value of x1_i >= x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("greater_equal", x1, x2) - return jax.numpy.greater_equal(x1, x2) - - +# TODO(micky774): Remove when jnp.hypot deprecation is completed +# (began 2024-4-14) and default behavior is Array API 2023 compliant def hypot(x1, x2, /): """Computes the square root of the sum of squares for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("hypot", x1, x2) + x1, x2 = promote_args("hypot", x1, x2) - # TODO(micky774): Remove when jnp.hypot deprecation is completed - # (began 2024-4-14) and default behavior is Array API 2023 compliant if issubdtype(x1.dtype, jax.numpy.complexfloating): raise ValueError( "hypot does not support complex-valued inputs. Please convert to real " @@ -229,214 +66,29 @@ def hypot(x1, x2, /): return jax.numpy.hypot(x1, x2) -def imag(x, /): - """Returns the imaginary component of a complex number for each element x_i of the input array x.""" - x, = _promote_dtypes("imag", x) - return jax.numpy.imag(x) - - -def isfinite(x, /): - """Tests each element x_i of the input array x to determine if finite.""" - x, = _promote_dtypes("isfinite", x) - return jax.numpy.isfinite(x) - - -def isinf(x, /): - """Tests each element x_i of the input array x to determine if equal to positive or negative infinity.""" - x, = _promote_dtypes("isinf", x) - return jax.numpy.isinf(x) - - -def isnan(x, /): - """Tests each element x_i of the input array x to determine whether the element is NaN.""" - x, = _promote_dtypes("isnan", x) - return jax.numpy.isnan(x) - - -def less(x1, x2, /): - """Computes the truth value of x1_i < x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("less", x1, x2) - return jax.numpy.less(x1, x2) - - -def less_equal(x1, x2, /): - """Computes the truth value of x1_i <= x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("less_equal", x1, x2) - return jax.numpy.less_equal(x1, x2) - - -def log(x, /): - """Calculates an implementation-dependent approximation to the natural (base e) logarithm for each element x_i of the input array x.""" - x, = _promote_dtypes("log", x) - return jax.numpy.log(x) - - -def log1p(x, /): - """Calculates an implementation-dependent approximation to log(1+x), where log refers to the natural (base e) logarithm, for each element x_i of the input array x.""" - x, = _promote_dtypes("log", x) - return jax.numpy.log1p(x) - - -def log2(x, /): - """Calculates an implementation-dependent approximation to the base 2 logarithm for each element x_i of the input array x.""" - x, = _promote_dtypes("log2", x) - return jax.numpy.log2(x) - - -def log10(x, /): - """Calculates an implementation-dependent approximation to the base 10 logarithm for each element x_i of the input array x.""" - x, = _promote_dtypes("log10", x) - return jax.numpy.log10(x) - - -def logaddexp(x1, x2, /): - """Calculates the logarithm of the sum of exponentiations log(exp(x1) + exp(x2)) for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("logaddexp", x1, x2) - return jax.numpy.logaddexp(x1, x2) - - +# TODO(micky774): Update jnp.logical_* binary ops signatures to positional only def logical_and(x1, x2, /): """Computes the logical AND for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("logical_and", x1, x2) + x1, x2 = promote_args("logical_and", x1, x2) return jax.numpy.logical_and(x1, x2) -def logical_not(x, /): - """Computes the logical NOT for each element x_i of the input array x.""" - x, = _promote_dtypes("logical_not", x) - return jax.numpy.logical_not(x) - - def logical_or(x1, x2, /): """Computes the logical OR for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("logical_or", x1, x2) + x1, x2 = promote_args("logical_or", x1, x2) return jax.numpy.logical_or(x1, x2) def logical_xor(x1, x2, /): """Computes the logical XOR for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("logical_xor", x1, x2) + x1, x2 = promote_args("logical_xor", x1, x2) return jax.numpy.logical_xor(x1, x2) -def maximum(x1, x2, /): - """Computes the maximum value for each element x1_i of the input array x1 relative to the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("maximum", x1, x2) - return jax.numpy.maximum(x1, x2) - - -def minimum(x1, x2, /): - """Computes the minimum value for each element x1_i of the input array x1 relative to the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("minimum", x1, x2) - return jax.numpy.minimum(x1, x2) - - -def multiply(x1, x2, /): - """Calculates the product for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("multiply", x1, x2) - return jax.numpy.multiply(x1, x2) - - -def negative(x, /): - """Computes the numerical negative of each element x_i (i.e., y_i = -x_i) of the input array x.""" - x, = _promote_dtypes("negative", x) - return jax.numpy.negative(x) - - -def not_equal(x1, x2, /): - """Computes the truth value of x1_i != x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("not_equal", x1, x2) - return jax.numpy.not_equal(x1, x2) - - -def positive(x, /): - """Computes the numerical positive of each element x_i (i.e., y_i = +x_i) of the input array x.""" - x, = _promote_dtypes("positive", x) - return x - - -def pow(x1, x2, /): - """Calculates an implementation-dependent approximation of exponentiation by raising each element x1_i (the base) of the input array x1 to the power of x2_i (the exponent), where x2_i is the corresponding element of the input array x2.""" - x1, x2 = _promote_dtypes("pow", x1, x2) - return jax.numpy.pow(x1, x2) - - -def real(x, /): - """Returns the real component of a complex number for each element x_i of the input array x.""" - x, = _promote_dtypes("real", x) - return jax.numpy.real(x) - - -def remainder(x1, x2, /): - """Returns the remainder of division for each element x1_i of the input array x1 and the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("remainder", x1, x2) - return jax.numpy.remainder(x1, x2) - - -def round(x, /): - """Rounds each element x_i of the input array x to the nearest integer-valued number.""" - x, = _promote_dtypes("round", x) - return jax.numpy.round(x) - - -def sign(x, /): - """Returns an indication of the sign of a number for each element x_i of the input array x.""" - x, = _promote_dtypes("sign", x) - if _isdtype(x.dtype, "complex floating"): - return x / abs(x) - return jax.numpy.sign(x) - - -def signbit(x, /): - """Determines whether the sign bit is set for each element x_i of the input array x.""" - return jax.numpy.signbit(x) - - -def sin(x, /): - """Calculates an implementation-dependent approximation to the sine for each element x_i of the input array x.""" - x, = _promote_dtypes("sin", x) - return jax.numpy.sin(x) - - -def sinh(x, /): - """Calculates an implementation-dependent approximation to the hyperbolic sine for each element x_i of the input array x.""" - x, = _promote_dtypes("sin", x) - return jax.numpy.sinh(x) - - -def square(x, /): - """Squares each element x_i of the input array x.""" - x, = _promote_dtypes("square", x) - return jax.numpy.square(x) - - -def sqrt(x, /): - """Calculates the principal square root for each element x_i of the input array x.""" - x, = _promote_dtypes("sqrt", x) - return jax.numpy.sqrt(x) - - -def subtract(x1, x2, /): - """Calculates the difference for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" - x1, x2 = _promote_dtypes("subtract", x1, x2) - return jax.numpy.subtract(x1, x2) - - -def tan(x, /): - """Calculates an implementation-dependent approximation to the tangent for each element x_i of the input array x.""" - x, = _promote_dtypes("tan", x) - return jax.numpy.tan(x) - - -def tanh(x, /): - """Calculates an implementation-dependent approximation to the hyperbolic tangent for each element x_i of the input array x.""" - x, = _promote_dtypes("tanh", x) - return jax.numpy.tanh(x) - - +# TODO(micky774): Update jnp.trunc to preserve integral dtype def trunc(x, /): """Rounds each element x_i of the input array x to the nearest integer-valued number that is closer to zero than x_i.""" - x, = _promote_dtypes("trunc", x) - if _isdtype(x.dtype, "integral"): + x, = promote_args("trunc", x) + if isdtype(x.dtype, "integral"): return x return jax.numpy.trunc(x) diff --git a/jax/experimental/array_api/_fft_functions.py b/jax/experimental/array_api/_fft_functions.py index d1e737a424ac..9b51dd628484 100644 --- a/jax/experimental/array_api/_fft_functions.py +++ b/jax/experimental/array_api/_fft_functions.py @@ -14,47 +14,8 @@ import jax.numpy as jnp - -def fft(x, /, *, n=None, axis=-1, norm='backward'): - """Computes the one-dimensional discrete Fourier transform.""" - return jnp.fft.fft(x, n=n, axis=axis, norm=norm) - -def ifft(x, /, *, n=None, axis=-1, norm='backward'): - """Computes the one-dimensional inverse discrete Fourier transform.""" - return jnp.fft.ifft(x, n=n, axis=axis, norm=norm) - -def fftn(x, /, *, s=None, axes=None, norm='backward'): - """Computes the n-dimensional discrete Fourier transform.""" - return jnp.fft.fftn(x, s=s, axes=axes, norm=norm) - -def ifftn(x, /, *, s=None, axes=None, norm='backward'): - """Computes the n-dimensional inverse discrete Fourier transform.""" - return jnp.fft.ifftn(x, s=s, axes=axes, norm=norm) - -def rfft(x, /, *, n=None, axis=-1, norm='backward'): - """Computes the one-dimensional discrete Fourier transform for real-valued input.""" - return jnp.fft.rfft(x, n=n, axis=axis, norm=norm) - -def irfft(x, /, *, n=None, axis=-1, norm='backward'): - """Computes the one-dimensional inverse of rfft for complex-valued input.""" - return jnp.fft.irfft(x, n=n, axis=axis, norm=norm) - -def rfftn(x, /, *, s=None, axes=None, norm='backward'): - """Computes the n-dimensional discrete Fourier transform for real-valued input.""" - return jnp.fft.rfftn(x, s=s, axes=axes, norm=norm) - -def irfftn(x, /, *, s=None, axes=None, norm='backward'): - """Computes the n-dimensional inverse of rfftn for complex-valued input.""" - return jnp.fft.irfftn(x, s=s, axes=axes, norm=norm) - -def hfft(x, /, *, n=None, axis=-1, norm='backward'): - """Computes the one-dimensional discrete Fourier transform of a signal with Hermitian symmetry.""" - return jnp.fft.hfft(x, n=n, axis=axis, norm=norm) - -def ihfft(x, /, *, n=None, axis=-1, norm='backward'): - """Computes the one-dimensional inverse discrete Fourier transform of a signal with Hermitian symmetry.""" - return jnp.fft.ihfft(x, n=n, axis=axis, norm=norm) - +# TODO(micky774): Remove after adding device parameter to corresponding jnp.fft +# functions. def fftfreq(n, /, *, d=1.0, device=None): """Returns the discrete Fourier transform sample frequencies.""" return jnp.fft.fftfreq(n, d=d).to_device(device) @@ -62,11 +23,3 @@ def fftfreq(n, /, *, d=1.0, device=None): def rfftfreq(n, /, *, d=1.0, device=None): """Returns the discrete Fourier transform sample frequencies (for rfft and irfft).""" return jnp.fft.rfftfreq(n, d=d).to_device(device) - -def fftshift(x, /, *, axes=None): - """Shift the zero-frequency component to the center of the spectrum.""" - return jnp.fft.fftshift(x, axes=axes) - -def ifftshift(x, /, *, axes=None): - """Inverse of fftshift.""" - return jnp.fft.ifftshift(x, axes=axes) diff --git a/jax/experimental/array_api/_indexing_functions.py b/jax/experimental/array_api/_indexing_functions.py deleted file mode 100644 index 261c81b20351..000000000000 --- a/jax/experimental/array_api/_indexing_functions.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jax - -def take(x, indices, /, *, axis): - return jax.numpy.take(x, indices, axis=axis) diff --git a/jax/experimental/array_api/_linear_algebra_functions.py b/jax/experimental/array_api/_linear_algebra_functions.py index 478af83092d7..8234a9f01858 100644 --- a/jax/experimental/array_api/_linear_algebra_functions.py +++ b/jax/experimental/array_api/_linear_algebra_functions.py @@ -13,140 +13,18 @@ # limitations under the License. import jax -from jax.experimental.array_api._data_type_functions import ( - _promote_to_default_dtype, -) - -def cholesky(x, /, *, upper=False): - """ - Returns the lower (upper) Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix x. - """ - return jax.numpy.linalg.cholesky(x, upper=upper) - -def cross(x1, x2, /, *, axis=-1): - """ - Returns the cross product of 3-element vectors. - """ - return jax.numpy.linalg.cross(x1, x2, axis=axis) - -def det(x, /): - """ - Returns the determinant of a square matrix (or a stack of square matrices) x. - """ - return jax.numpy.linalg.det(x) - -def diagonal(x, /, *, offset=0): - """ - Returns the specified diagonals of a matrix (or a stack of matrices) x. - """ - return jax.numpy.linalg.diagonal(x, offset=offset) - -def eigh(x, /): - """ - Returns an eigenvalue decomposition of a complex Hermitian or real symmetric matrix (or a stack of matrices) x. - """ - return jax.numpy.linalg.eigh(x) - -def eigvalsh(x, /): - """ - Returns the eigenvalues of a complex Hermitian or real symmetric matrix (or a stack of matrices) x. - """ - return jax.numpy.linalg.eigvalsh(x) - -def inv(x, /): - """ - Returns the multiplicative inverse of a square matrix (or a stack of square matrices) x. - """ - return jax.numpy.linalg.inv(x) - -def matmul(x1, x2, /): - """Computes the matrix product.""" - return jax.numpy.linalg.matmul(x1, x2) - -def matrix_norm(x, /, *, keepdims=False, ord='fro'): - """ - Computes the matrix norm of a matrix (or a stack of matrices) x. - """ - return jax.numpy.linalg.matrix_norm(x, ord=ord, keepdims=keepdims) - -def matrix_power(x, n, /): - """ - Raises a square matrix (or a stack of square matrices) x to an integer power n. - """ - return jax.numpy.linalg.matrix_power(x, n) +# TODO(micky774): Remove after deprecating tol-->rtol in jnp.linalg.matrix_rank def matrix_rank(x, /, *, rtol=None): """ Returns the rank (i.e., number of non-zero singular values) of a matrix (or a stack of matrices). """ return jax.numpy.linalg.matrix_rank(x, tol=rtol) -def matrix_transpose(x, /): - """Transposes a matrix (or a stack of matrices) x.""" - return jax.numpy.linalg.matrix_transpose(x) - -def outer(x1, x2, /): - """ - Returns the outer product of two vectors x1 and x2. - """ - return jax.numpy.linalg.outer(x1, x2) - +# TODO(micky774): Remove after deprecating rcond-->rtol in +# jnp.linalg.pinv def pinv(x, /, *, rtol=None): """ Returns the (Moore-Penrose) pseudo-inverse of a matrix (or a stack of matrices) x. """ return jax.numpy.linalg.pinv(x, rcond=rtol) - -def qr(x, /, *, mode='reduced'): - """ - Returns the QR decomposition of a full column rank matrix (or a stack of matrices). - """ - return jax.numpy.linalg.qr(x, mode=mode) - -def slogdet(x, /): - """ - Returns the sign and the natural logarithm of the absolute value of the determinant of a square matrix (or a stack of square matrices) x. - """ - return jax.numpy.linalg.slogdet(x) - -def solve(x1, x2, /): - """ - Returns the solution of a square system of linear equations with a unique solution. - """ - if x2.ndim == 1: - signature = "(m,m),(m)->(m)" - else: - signature = "(m,m),(m,n)->(m,n)" - return jax.numpy.vectorize(jax.numpy.linalg.solve, signature=signature)(x1, x2) - - -def svd(x, /, *, full_matrices=True): - """ - Returns a singular value decomposition (SVD) of a matrix (or a stack of matrices) x. - """ - return jax.numpy.linalg.svd(x, full_matrices=full_matrices) - -def svdvals(x, /): - """ - Returns the singular values of a matrix (or a stack of matrices) x. - """ - return jax.numpy.linalg.svdvals(x) - -def tensordot(x1, x2, /, *, axes=2): - """Returns a tensor contraction of x1 and x2 over specific axes.""" - return jax.numpy.linalg.tensordot(x1, x2, axes=axes) - -def trace(x, /, *, offset=0, dtype=None): - """ - Returns the sum along the specified diagonals of a matrix (or a stack of matrices) x. - """ - x = _promote_to_default_dtype(x) - return jax.numpy.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1) - -def vecdot(x1, x2, /, *, axis=-1): - """Computes the (vector) dot product of two arrays.""" - return jax.numpy.linalg.vecdot(x1, x2, axis=axis) - -def vector_norm(x, /, *, axis=None, keepdims=False, ord=2): - """Computes the vector norm of a vector (or batch of vectors) x.""" - return jax.numpy.linalg.vector_norm(x, axis=axis, keepdims=keepdims, ord=ord) diff --git a/jax/experimental/array_api/_manipulation_functions.py b/jax/experimental/array_api/_manipulation_functions.py index e0e7a8bccaf6..4bc996c51fdc 100644 --- a/jax/experimental/array_api/_manipulation_functions.py +++ b/jax/experimental/array_api/_manipulation_functions.py @@ -16,79 +16,10 @@ import jax from jax import Array -from jax.experimental.array_api._data_type_functions import result_type as _result_type - - -def broadcast_arrays(*arrays: Array) -> list[Array]: - """Broadcasts one or more arrays against one another.""" - return jax.numpy.broadcast_arrays(*arrays) - - -def broadcast_to(x: Array, /, shape: tuple[int, ...]) -> Array: - """Broadcasts an array to a specified shape.""" - return jax.numpy.broadcast_to(x, shape=shape) - - -def concat(arrays: tuple[Array, ...] | list[Array], /, *, axis: int | None = 0) -> Array: - """Joins a sequence of arrays along an existing axis.""" - dtype = _result_type(*arrays) - return jax.numpy.concat([arr.astype(dtype) for arr in arrays], axis=axis) - - -def expand_dims(x: Array, /, *, axis: int = 0) -> Array: - """Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by axis.""" - if axis < -x.ndim - 1 or axis > x.ndim: - raise IndexError(f"{axis=} is out of bounds for array of dimension {x.ndim}") - return jax.numpy.expand_dims(x, axis=axis) - - -def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None) -> Array: - """Reverses the order of elements in an array along the given axis.""" - return jax.numpy.flip(x, axis=axis) - - -def moveaxis(x: Array, source: int | tuple[int, ...], destination: int | tuple[int, ...], /) -> Array: - """Moves array axes (dimensions) to new positions, while leaving other axes in their original positions.""" - return jax.numpy.moveaxis(x, source, destination) - - -def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array: - """Permutes the axes (dimensions) of an array x.""" - return jax.numpy.permute_dims(x, axes=axes) - - -def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array: - """Repeats each element of an array a specified number of times on a per-element basis.""" - return jax.numpy.repeat(x, repeats=repeats, axis=axis) +# TODO(micky774): Deprecate newshape-->shape in for array API 2023.12 def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array: """Reshapes an array without changing its data.""" del copy # unused return jax.numpy.reshape(x, shape) - - -def roll(x: Array, /, shift: int | tuple[int, ...], *, axis: int | tuple[int, ...] | None = None) -> Array: - """Rolls array elements along a specified axis.""" - return jax.numpy.roll(x, shift=shift, axis=axis) - - -def squeeze(x: Array, /, axis: int | tuple[int, ...]) -> Array: - """Removes singleton dimensions (axes) from x.""" - return jax.numpy.squeeze(x, axis=axis) - - -def stack(arrays: tuple[Array, ...] | list[Array], /, *, axis: int = 0) -> Array: - """Joins a sequence of arrays along a new axis.""" - dtype = _result_type(*arrays) - return jax.numpy.stack(arrays, axis=axis, dtype=dtype) - - -def tile(x: Array, repetitions: tuple[int], /) -> Array: - """Constructs an array by tiling an input array.""" - return jax.numpy.tile(x, repetitions) - - -def unstack(x: Array, /, *, axis: int = 0) -> tuple[Array, ...]: - """Splits an array in a sequence of arrays along the given axis.""" - return jax.numpy.unstack(x, axis=axis) diff --git a/jax/experimental/array_api/_searching_functions.py b/jax/experimental/array_api/_searching_functions.py deleted file mode 100644 index f329e4add813..000000000000 --- a/jax/experimental/array_api/_searching_functions.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jax -from jax.experimental.array_api._data_type_functions import result_type as _result_type - - -def argmax(x, /, *, axis=None, keepdims=False): - """Returns the indices of the maximum values along a specified axis.""" - return jax.numpy.argmax(x, axis=axis, keepdims=keepdims) - - -def argmin(x, /, *, axis=None, keepdims=False): - """Returns the indices of the minimum values along a specified axis.""" - return jax.numpy.argmin(x, axis=axis, keepdims=keepdims) - - -def nonzero(x, /): - """Returns the indices of the array elements which are non-zero.""" - if jax.numpy.ndim(x) == 0: - raise ValueError("inputs to nonzero() must have at least one dimension.") - return jax.numpy.nonzero(x) - - -def searchsorted(x1, x2, /, *, side='left', sorter=None): - """ - Finds the indices into x1 such that, if the corresponding elements in x2 - were inserted before the indices, the order of x1, when sorted in ascending - order, would be preserved. - """ - return jax.numpy.searchsorted(x1, x2, side=side, sorter=sorter) - - -def where(condition, x1, x2, /): - """Returns elements chosen from x1 or x2 depending on condition.""" - dtype = _result_type(x1, x2) - return jax.numpy.where(condition, x1.astype(dtype), x2.astype(dtype)) diff --git a/jax/experimental/array_api/_set_functions.py b/jax/experimental/array_api/_set_functions.py deleted file mode 100644 index c9f539d5ec06..000000000000 --- a/jax/experimental/array_api/_set_functions.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jax - - -def unique_all(x, /): - """Returns the unique elements of an input array x, the first occurring indices for each unique element in x, the indices from the set of unique elements that reconstruct x, and the corresponding counts for each unique element in x.""" - return jax.numpy.unique_all(x) - - -def unique_counts(x, /): - """Returns the unique elements of an input array x and the corresponding counts for each unique element in x.""" - return jax.numpy.unique_counts(x) - - -def unique_inverse(x, /): - """Returns the unique elements of an input array x and the indices from the set of unique elements that reconstruct x.""" - return jax.numpy.unique_inverse(x) - - -def unique_values(x, /): - """Returns the unique elements of an input array x.""" - return jax.numpy.unique_values(x) diff --git a/jax/experimental/array_api/_sorting_functions.py b/jax/experimental/array_api/_sorting_functions.py deleted file mode 100644 index 4c64480d39a6..000000000000 --- a/jax/experimental/array_api/_sorting_functions.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jax -from jax import Array - - -def argsort(x: Array, /, *, axis: int = -1, descending: bool = False, - stable: bool = True) -> Array: - """Returns the indices that sort an array x along a specified axis.""" - return jax.numpy.argsort(x, axis=axis, descending=descending, stable=stable) - - -def sort(x: Array, /, *, axis: int = -1, descending: bool = False, - stable: bool = True) -> Array: - """Returns a sorted copy of an input array x.""" - return jax.numpy.sort(x, axis=axis, descending=descending, stable=stable) diff --git a/jax/experimental/array_api/_statistical_functions.py b/jax/experimental/array_api/_statistical_functions.py index 141b80abfb14..c34fb1fc3af4 100644 --- a/jax/experimental/array_api/_statistical_functions.py +++ b/jax/experimental/array_api/_statistical_functions.py @@ -13,47 +13,14 @@ # limitations under the License. import jax -from jax.experimental.array_api._data_type_functions import ( - _promote_to_default_dtype, -) - - -def cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False): - """Calculates the cumulative sum of elements in the input array x.""" - return jax.numpy.cumulative_sum(x, axis=axis, dtype=dtype, include_initial=include_initial) - -def max(x, /, *, axis=None, keepdims=False): - """Calculates the maximum value of the input array x.""" - return jax.numpy.max(x, axis=axis, keepdims=keepdims) - - -def mean(x, /, *, axis=None, keepdims=False): - """Calculates the arithmetic mean of the input array x.""" - return jax.numpy.mean(x, axis=axis, keepdims=keepdims) - - -def min(x, /, *, axis=None, keepdims=False): - """Calculates the minimum value of the input array x.""" - return jax.numpy.min(x, axis=axis, keepdims=keepdims) - - -def prod(x, /, *, axis=None, dtype=None, keepdims=False): - """Calculates the product of input array x elements.""" - x = _promote_to_default_dtype(x) - return jax.numpy.prod(x, axis=axis, dtype=dtype, keepdims=keepdims) - +# TODO(micky774): Remove after deprecating ddof-->correction in jnp.std and +# jnp.var def std(x, /, *, axis=None, correction=0.0, keepdims=False): """Calculates the standard deviation of the input array x.""" return jax.numpy.std(x, axis=axis, ddof=correction, keepdims=keepdims) -def sum(x, /, *, axis=None, dtype=None, keepdims=False): - """Calculates the sum of the input array x.""" - x = _promote_to_default_dtype(x) - return jax.numpy.sum(x, axis=axis, dtype=dtype, keepdims=keepdims) - - def var(x, /, *, axis=None, correction=0.0, keepdims=False): """Calculates the variance of the input array x.""" return jax.numpy.var(x, axis=axis, ddof=correction, keepdims=keepdims) diff --git a/jax/experimental/array_api/_utility_functions.py b/jax/experimental/array_api/_utility_functions.py index 1b4d87f3c852..c5dac25fd8c6 100644 --- a/jax/experimental/array_api/_utility_functions.py +++ b/jax/experimental/array_api/_utility_functions.py @@ -20,15 +20,8 @@ from jax._src.lib import xla_client as xc from jax._src import dtypes as _dtypes, config -def all(x, /, *, axis=None, keepdims=False): - """Tests whether all input array elements evaluate to True along a specified axis.""" - return jax.numpy.all(x, axis=axis, keepdims=keepdims) - - -def any(x, /, *, axis=None, keepdims=False): - """Tests whether any input array element evaluates to True along a specified axis.""" - return jax.numpy.any(x, axis=axis, keepdims=keepdims) - +# TODO(micky774): Add to jax.numpy.util when finalizing jax.experimental.array_api +# deprecation class __array_namespace_info__: def __init__(self): diff --git a/jax/experimental/array_api/_version.py b/jax/experimental/array_api/_version.py index 4936af86da4c..104df73c77b9 100644 --- a/jax/experimental/array_api/_version.py +++ b/jax/experimental/array_api/_version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__array_api_version__ = '2022.12' +__array_api_version__ = '2023.12' diff --git a/jax/experimental/array_api/fft.py b/jax/experimental/array_api/fft.py index f83d45401d20..a17ed3cec130 100644 --- a/jax/experimental/array_api/fft.py +++ b/jax/experimental/array_api/fft.py @@ -12,19 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax.experimental.array_api._fft_functions import ( +from jax.numpy.fft import ( fft as fft, - fftfreq as fftfreq, - fftn as fftn, - fftshift as fftshift, - hfft as hfft, ifft as ifft, - ifftn as ifftn, - ifftshift as ifftshift, - ihfft as ihfft, + rfft as rfft, irfft as irfft, + fftn as fftn, + ifftn as ifftn, + rfftn as rfftn, irfftn as irfftn, - rfft as rfft, + hfft as hfft, + ihfft as ihfft, + fftshift as fftshift, + ifftshift as ifftshift, +) + +from jax.experimental.array_api._fft_functions import ( + fftfreq as fftfreq, rfftfreq as rfftfreq, - rfftn as rfftn, ) diff --git a/jax/experimental/array_api/linalg.py b/jax/experimental/array_api/linalg.py index 49c93c5b1908..2e6b7ec8ee57 100644 --- a/jax/experimental/array_api/linalg.py +++ b/jax/experimental/array_api/linalg.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax.experimental.array_api._linear_algebra_functions import ( +from jax.numpy.linalg import ( cholesky as cholesky, cross as cross, det as det, @@ -22,18 +22,23 @@ inv as inv, matmul as matmul, matrix_norm as matrix_norm, - matrix_power as matrix_power, - matrix_rank as matrix_rank, matrix_transpose as matrix_transpose, outer as outer, - pinv as pinv, qr as qr, slogdet as slogdet, solve as solve, svd as svd, svdvals as svdvals, tensordot as tensordot, - trace as trace, vecdot as vecdot, vector_norm as vector_norm, + matrix_power as matrix_power, +) + +# TODO(micky774): Add trace to jax.numpy.linalg +from jax.numpy import trace as trace + +from jax.experimental.array_api._linear_algebra_functions import ( + matrix_rank as matrix_rank, + pinv as pinv, ) diff --git a/jax/experimental/array_api/skips.txt b/jax/experimental/array_api/skips.txt index c502f4ac5999..2fb90b593add 100644 --- a/jax/experimental/array_api/skips.txt +++ b/jax/experimental/array_api/skips.txt @@ -11,3 +11,9 @@ array_api_tests/test_special_cases.py::test_unary # fft test suite is buggy as of 83f0bcdc array_api_tests/test_fft.py + +# Pending implementation update for proper dtype promotion behavior, +# see https://github.com/data-apis/array-api-tests/issues/234 +array_api_tests/test_statistical_functions.py::test_sum +array_api_tests/test_statistical_functions.py::test_prod +array_api_tests/test_linalg.py::test_trace \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 300a2f3dae8e..225c79e0368f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,11 @@ filterwarnings = [ "ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning", "ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning", "ignore:The host_callback APIs are deprecated .*:DeprecationWarning", + "ignore: