Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating JAX_FFT cuFFTMp to work with JAX 0.4.30 #198

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion cuFFTMp/JAX_FFT/src/cufftmp_jax/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@ find_package(pybind11 CONFIG REQUIRED)

include_directories(${CMAKE_CURRENT_LIST_DIR}/lib)

set(NVSHMEM_HOME $ENV{NVHPC_ROOT}/comm_libs/12.2/nvshmem_cufftmp_compat)
set(CUFFTMP_HOME $ENV{NVHPC_ROOT}/math_libs)
message(STATUS "Using ${NVSHMEM_HOME} for NVSHMEM_HOME and ${CUFFTMP_HOME} for CUFFTMP_HOME")
include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${CUFFTMP_HOME}/include ${NVSHMEM_HOME}/include)


include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
${CUFFTMP_HOME}/include ${NVSHMEM_HOME}/include
$ENV{CUFFT_INC}
)
link_directories(${CUFFTMP_HOME}/lib ${NVSHMEM_HOME}/lib)

pybind11_add_module(gpu_ops
Expand Down
34 changes: 18 additions & 16 deletions cuFFTMp/JAX_FFT/src/cufftmp_jax/cufftmp_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from jax.lib import xla_client
from jax import core, dtypes
from jax.interpreters import xla, mlir
from jax.abstract_arrays import ShapedArray
from jax._src.sharding import NamedSharding
from jax.core import ShapedArray
from jax.sharding import NamedSharding
from jax.experimental.custom_partitioning import custom_partitioning
from jaxlib.hlo_helpers import custom_call

Expand All @@ -30,7 +30,7 @@
def _cufftmp_bind(input, num_parts, dist, dir):

# param=val means it's a static parameter
(output,) = _cufftmp_prim.bind(input,
output = _cufftmp_prim.bind(input,
num_parts=num_parts,
dist=dist,
dir=dir)
Expand Down Expand Up @@ -110,7 +110,7 @@ def cufftmp(x, dist, dir):

@custom_partitioning
def _cufftmp_(x):
return _cufftmp_bind(x, num_parts=1, dist=dist, dir=dir)
return _cufftmp_bind(x, num_parts=jax.device_count(), dist=dist, dir=dir)

_cufftmp_.def_partition(
infer_sharding_from_operands=partial(
Expand Down Expand Up @@ -180,18 +180,20 @@ def _cufftmp_translation(ctx, input, num_parts, dist, dir):
else:
raise ValueError("Unsupported tensor rank; must be 2 or 3")

return [custom_call(
"gpu_cufftmp",
# Output types
out_types=[output_type],
# The inputs:
operands=[input,],
# Layout specification:
operand_layouts=[layout,],
result_layouts=[layout,],
# GPU specific additional data
backend_config=opaque
)]
out = custom_call(
"gpu_cufftmp",
# Output types
result_types=[output_type],
# The inputs:
operands=[input,],
# Layout specification:
operand_layouts=[layout,],
result_layouts=[layout,],
# GPU specific additional data
backend_config=opaque
)

return out.results


# *********************************************
Expand Down
2 changes: 1 addition & 1 deletion cuFFTMp/JAX_FFT/src/fft_common/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum

import jax
from jax.experimental import PartitionSpec
from jax.sharding import PartitionSpec


class Dist(Enum):
Expand Down
2 changes: 1 addition & 1 deletion cuFFTMp/JAX_FFT/src/xfft/xfft.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import partial

import jax
from jax._src.sharding import NamedSharding
from jax.sharding import NamedSharding
from jax.experimental.custom_partitioning import custom_partitioning
from fft_common import Dir

Expand Down
18 changes: 12 additions & 6 deletions cuFFTMp/JAX_FFT/tests/fft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
from fft_common import Dist, Dir
from cufftmp_jax import cufftmp
from xfft import xfft

from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
import helpers
from jax.experimental import mesh_utils, multihost_utils


def main():
Expand Down Expand Up @@ -45,10 +47,17 @@ def main():
raise ValueError(f"Wrong implementation: got {impl}, expected cufftmp or xfft")

dist = Dist.create(opt['dist'])
if dist == Dist.SLABS_X:
pdims = [jax.device_count(), 1]
axis_names = ('gpus', None)
elif dist == Dist.SLABS_Y:
pdims = [1, jax.device_count()]
axis_names = (None, 'gpus')
input_shape = dist.slab_shape(fft_dims)
dtype = jnp.complex64

mesh = maps.Mesh(np.asarray(jax.devices()), ('gpus',))
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=axis_names)

with jax.spmd_mode('allow_all'):

Expand All @@ -60,10 +69,7 @@ def main():

with mesh:

fft = pjit(dist_fft,
in_axis_resources=None,
out_axis_resources=None,
static_argnums=[1, 2])
fft = jax.jit(dist_fft,static_argnums=[1, 2])

output = fft(input, dist, Dir.FWD)

Expand Down