Skip to content

Commit

Permalink
Refactor array_api namespace, relying more directly on jax.numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Apr 30, 2024
1 parent a949ce7 commit 5c0b3de
Show file tree
Hide file tree
Showing 23 changed files with 171 additions and 1,145 deletions.
7 changes: 6 additions & 1 deletion jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,10 @@ def isdtype(dtype: DTypeLike, kind: str | DType | tuple[str | DType, ...]) -> bo
options.update(_dtype_kinds[kind])
elif isinstance(kind, np.dtype):
options.add(kind)
# Check for _ScalarMeta without referencing the class directly
elif issubclass(kind.__class__, type) and isinstance(getattr(kind, 'dtype'), np.dtype):
options.add(kind.dtype)
else:
# TODO(jakevdp): should we handle scalar types or ScalarMeta here?
raise TypeError(f"Expected kind to be a dtype, string, or tuple; got {kind=}")
return the_dtype in options

Expand Down Expand Up @@ -652,6 +654,9 @@ def check_valid_dtype(dtype: DType) -> None:
raise TypeError(f"Dtype {dtype} is not a valid JAX array "
"type. Only arrays of numeric types are supported by JAX.")

def is_valid_dtype(dtype: DType) -> bool:
return dtype in _jax_dtype_set

def dtype(x: Any, *, canonicalize: bool = False) -> DType:
"""Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
if x is None:
Expand Down
164 changes: 72 additions & 92 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
>>> from jax.experimental import array_api as xp
>>> xp.__array_api_version__
'2022.12'
'2023.12'
>>> arr = xp.arange(1000)
Expand All @@ -38,64 +38,17 @@

from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__

from jax.experimental.array_api import (
fft as fft,
linalg as linalg,
)
from jax.experimental.array_api import fft as fft
from jax.experimental.array_api import linalg as linalg

from jax.experimental.array_api._constants import (
from jax.numpy import (
e as e,
inf as inf,
nan as nan,
newaxis as newaxis,
pi as pi,
)

from jax.experimental.array_api._creation_functions import (
arange as arange,
asarray as asarray,
empty as empty,
empty_like as empty_like,
eye as eye,
from_dlpack as from_dlpack,
full as full,
full_like as full_like,
linspace as linspace,
meshgrid as meshgrid,
ones as ones,
ones_like as ones_like,
tril as tril,
triu as triu,
zeros as zeros,
zeros_like as zeros_like,
)

from jax.experimental.array_api._data_type_functions import (
astype as astype,
can_cast as can_cast,
finfo as finfo,
iinfo as iinfo,
isdtype as isdtype,
result_type as result_type,
)

from jax.experimental.array_api._dtypes import (
bool as bool,
int8 as int8,
int16 as int16,
int32 as int32,
int64 as int64,
uint8 as uint8,
uint16 as uint16,
uint32 as uint32,
uint64 as uint64,
float32 as float32,
float64 as float64,
complex64 as complex64,
complex128 as complex128,
)

from jax.experimental.array_api._elementwise_functions import (
abs as abs,
acos as acos,
acosh as acosh,
Expand All @@ -111,8 +64,6 @@
bitwise_or as bitwise_or,
bitwise_right_shift as bitwise_right_shift,
bitwise_xor as bitwise_xor,
ceil as ceil,
clip as clip,
conj as conj,
copysign as copysign,
cos as cos,
Expand All @@ -121,11 +72,9 @@
equal as equal,
exp as exp,
expm1 as expm1,
floor as floor,
floor_divide as floor_divide,
greater as greater,
greater_equal as greater_equal,
hypot as hypot,
imag as imag,
isfinite as isfinite,
isinf as isinf,
Expand All @@ -137,10 +86,7 @@
log1p as log1p,
log2 as log2,
logaddexp as logaddexp,
logical_and as logical_and,
logical_not as logical_not,
logical_or as logical_or,
logical_xor as logical_xor,
maximum as maximum,
minimum as minimum,
multiply as multiply,
Expand All @@ -151,23 +97,15 @@
real as real,
remainder as remainder,
round as round,
sign as sign,
signbit as signbit,
sin as sin,
sinh as sinh,
sqrt as sqrt,
square as square,
subtract as subtract,
tan as tan,
tanh as tanh,
trunc as trunc,
)

from jax.experimental.array_api._indexing_functions import (
take as take,
)

from jax.experimental.array_api._manipulation_functions import (
tanh as tanh,
broadcast_arrays as broadcast_arrays,
broadcast_to as broadcast_to,
concat as concat,
Expand All @@ -176,56 +114,98 @@
moveaxis as moveaxis,
permute_dims as permute_dims,
repeat as repeat,
reshape as reshape,
roll as roll,
squeeze as squeeze,
stack as stack,
tile as tile,
unstack as unstack,
)

from jax.experimental.array_api._searching_functions import (
argmax as argmax,
argmin as argmin,
nonzero as nonzero,
searchsorted as searchsorted,
where as where,
)

from jax.experimental.array_api._set_functions import (
unique_all as unique_all,
unique_counts as unique_counts,
unique_inverse as unique_inverse,
unique_values as unique_values,
)

from jax.experimental.array_api._sorting_functions import (
argsort as argsort,
sort as sort,
)

from jax.experimental.array_api._statistical_functions import (
cumulative_sum as cumulative_sum,
max as max,
mean as mean,
min as min,
prod as prod,
std as std,
sum as sum,
var as var
)

from jax.experimental.array_api._utility_functions import (
__array_namespace_info__ as __array_namespace_info__,
all as all,
any as any,
)

from jax.experimental.array_api._linear_algebra_functions import (
from_dlpack as from_dlpack,
meshgrid as meshgrid,
empty as empty,
empty_like as empty_like,
full as full,
full_like as full_like,
ones as ones,
ones_like as ones_like,
zeros as zeros,
zeros_like as zeros_like,
can_cast as can_cast,
isdtype as isdtype,
result_type as result_type,
iinfo as iinfo,
sign as sign,
nonzero as nonzero,
prod as prod,
sum as sum,
matmul as matmul,
matrix_transpose as matrix_transpose,
tensordot as tensordot,
vecdot as vecdot,
bool as bool,
int8 as int8,
int16 as int16,
int32 as int32,
int64 as int64,
uint8 as uint8,
uint16 as uint16,
uint32 as uint32,
uint64 as uint64,
float32 as float32,
float64 as float64,
complex64 as complex64,
complex128 as complex128,
)

from jax.experimental.array_api._manipulation_functions import (
reshape as reshape,
)

from jax.experimental.array_api._creation_functions import (
arange as arange,
asarray as asarray,
eye as eye,
linspace as linspace,
)

from jax.experimental.array_api._data_type_functions import (
astype as astype,
finfo as finfo,
)

from jax.experimental.array_api._elementwise_functions import (
ceil as ceil,
clip as clip,
floor as floor,
hypot as hypot,
logical_and as logical_and,
logical_or as logical_or,
logical_xor as logical_xor,
trunc as trunc,
)

from jax.experimental.array_api._statistical_functions import (
std as std,
var as var,
)

from jax.experimental.array_api._utility_functions import (
__array_namespace_info__ as __array_namespace_info__,
)

from jax.experimental.array_api import _array_methods
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/array_api/_array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from typing import Any, Callable
from typing import Any

import jax
from jax._src.array import ArrayImpl
Expand Down
21 changes: 0 additions & 21 deletions jax/experimental/array_api/_constants.py

This file was deleted.

39 changes: 1 addition & 38 deletions jax/experimental/array_api/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,53 +16,16 @@

import jax
import jax.numpy as jnp
from jax._src.lib import xla_client as xc
from jax._src.sharding import Sharding

# TODO(micky774): Deprecate after adding device argument to jax.numpy functions
def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device)

def asarray(obj, /, *, dtype=None, device=None, copy=None):
return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device)

def empty(shape, *, dtype=None, device=None):
return jax.device_put(jnp.empty(shape, dtype=dtype), device=device)

def empty_like(x, /, *, dtype=None, device=None):
return jax.device_put(jnp.empty_like(x, dtype=dtype), device=device)

def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)

def from_dlpack(x, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None):
return jnp.from_dlpack(x, device=device, copy=copy)

def full(shape, fill_value, *, dtype=None, device=None):
return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device)

def full_like(x, /, fill_value, *, dtype=None, device=None):
return jax.device_put(jnp.full_like(x, fill_value=fill_value, dtype=dtype), device=device)

def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True):
return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device)

def meshgrid(*arrays, indexing='xy'):
return jnp.meshgrid(*arrays, indexing=indexing)

def ones(shape, *, dtype=None, device=None):
return jax.device_put(jnp.ones(shape, dtype=dtype), device=device)

def ones_like(x, /, *, dtype=None, device=None):
return jax.device_put(jnp.ones_like(x, dtype=dtype), device=device)

def tril(x, /, *, k=0):
return jnp.tril(x, k=k)

def triu(x, /, *, k=0):
return jnp.triu(x, k=k)

def zeros(shape, *, dtype=None, device=None):
return jax.device_put(jnp.zeros(shape, dtype=dtype), device=device)

def zeros_like(x, /, *, dtype=None, device=None):
return jax.device_put(jnp.zeros_like(x, dtype=dtype), device=device)
Loading

0 comments on commit 5c0b3de

Please sign in to comment.