From de47856c86896232fb28c242c71a2e6e0cfa6ccd Mon Sep 17 00:00:00 2001 From: mrava87 Date: Fri, 8 Nov 2024 14:15:15 +0300 Subject: [PATCH] minor: cleanup of kirchhoff and _kirchhoff_cuda --- docs/source/api/index.rst | 3 +- pylops/waveeqprocessing/_kirchhoff_cuda.py | 514 +++++++++++---------- pylops/waveeqprocessing/kirchhoff.py | 52 ++- 3 files changed, 308 insertions(+), 261 deletions(-) diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 59e40e51..c11548f2 100755 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -62,7 +62,8 @@ Basic operators Real Imag Conj - + ToCupy + Smoothing and derivatives ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/pylops/waveeqprocessing/_kirchhoff_cuda.py b/pylops/waveeqprocessing/_kirchhoff_cuda.py index 18aa720c..a5e55b29 100644 --- a/pylops/waveeqprocessing/_kirchhoff_cuda.py +++ b/pylops/waveeqprocessing/_kirchhoff_cuda.py @@ -1,76 +1,108 @@ -from numba import cuda -import numpy as np from math import cos +import numpy as np +from numba import cuda + +from pylops.utils.backend import to_cupy + + +class _KirchhoffCudaHelper: + """A helper class for performing Kirchhoff migration or modeling using Numba CUDA. + + This class provides methods to compute the forward and adjoint operations for the + Kirchhoff operator, utilizing GPU acceleration through Numba's CUDA capabilities. -class _kirchhoffCudaHelper: - """ A helper class for performing Kirchhoff migration or modeling using CUDA via Numba. - - This class provides methods to compute the forward and adjoint operations for Kirchhoff migration, - utilizing GPU acceleration through Numba's CUDA capabilities. - - Parameters - ---------- - ns : int - Number of sources. - nr : int - Number of receivers. - nt : int - Number of time samples. - ni : int - Number of image points. - dynamic : int, optional - Flag indicating whether to use dynamic computation. ``True`` == 1 or not ``False`` == 0 (default is 0). 
- travsrcrec : int, optional - Flag indicating whether to use separate tables for src and rec traveltimes. Seperate == 1 (default is 0). - """ - - def __init__(self, ns, nr, nt, ni, dynamic=0, travsrcrec=0): - self.dynamic, self.travsrcrec = dynamic, travsrcrec + Parameters + ---------- + ns : :obj:`int` + Number of sources. + nr : :obj:`int` + Number of receivers. + nt : :obj:`int` + Number of time samples. + ni : :obj:`int` + Number of image points. + dynamic : :obj:`int`, optional + Flag indicating whether to use dynamic computation. ``True`` == 1 or not ``False`` == 0 (default is 0). + nstreams : :obj:`int`, optional + Number of streams used in case of ``dynamic=True``. + + """ + + def __init__(self, ns, nr, nt, ni, dynamic=0, nstreams=1): self.ns, self.nr, self.nt, self.ni = ns, nr, nt, ni + self.dynamic = dynamic + self.nstreams = nstreams - self._lunch_grid_setup() + self._grid_setup() - def _lunch_grid_setup(self): - """ Set up the CUDA grid and block dimensions based on the current device and computation flags. - This method configures the number of blocks and threads per block for CUDA kernels, - depending on whether dynamic computation and travel times from sources and receivers are used. + def _grid_setup(self): + """Set up the CUDA grid and block dimensions based on the current device and computation flags. + + This method configures the number of blocks and threads per block for CUDA kernels, + depending on whether dynamic weights are used in computations or not. 
""" current_device = cuda.get_current_device() if self.dynamic: num_sources = self.ns - num_streams = 3 + num_streams = self.nstreams streams = [cuda.stream() for _ in range(num_streams)] sources_per_stream = num_sources // num_streams remainder = num_sources % num_streams source_counter = 0 self.sources_per_streams = {} for i, stream in enumerate(streams): - num_sources_for_stream = sources_per_stream + (1 if i < remainder else 0) - sources_for_stream = list(range(source_counter, source_counter + num_sources_for_stream)) + num_sources_for_stream = sources_per_stream + ( + 1 if i < remainder else 0 + ) + sources_for_stream = list( + range(source_counter, source_counter + num_sources_for_stream) + ) self.sources_per_streams[stream] = sources_for_stream source_counter += num_sources_for_stream - n_runs = num_streams * self.ns * self.ni * self.nr # number_of_times_to_run_kernel + n_runs = ( + num_streams * self.ns * self.ni * self.nr + ) # number_of_times_to_run_kernel self.num_threads_per_blocks = current_device.WARP_SIZE * 8 - self.num_blocks = (n_runs + (self.num_threads_per_blocks - 1)) // self.num_threads_per_blocks - # print(num_threads_per_blocks) - # print(self.blocks) + self.num_blocks = ( + n_runs + (self.num_threads_per_blocks - 1) + ) // self.num_threads_per_blocks else: - if not self.travsrcrec: - # version 4 - self.num_threads_per_blocks = current_device.WARP_SIZE * 8 - self.num_blocks = ((self.ns * self.nr) + ( - self.num_threads_per_blocks - 1)) // self.num_threads_per_blocks - else: - # version 3 - wrap = current_device.WARP_SIZE - multipr_count = current_device.MULTIPROCESSOR_COUNT - self.num_threads_per_blocks = (wrap, wrap) - self.num_blocks = (multipr_count, multipr_count) - - def _data_prep_dynamic(self, ns, nr, nt, ni, nz, dt, aperture, angleaperture, aperturetap, vel, six, - rix, trav_recs, angle_recs, trav_srcs, angle_srcs, amp_srcs, amp_recs): - """ Prepare data for dynamic computation by transfering some variables to device memory in advance 
once.""" + warp = current_device.WARP_SIZE + self.num_threads_per_blocks = (warp, warp) + self.num_blocks = ( + (self.ns + self.num_threads_per_blocks[0] - 1) + // self.num_threads_per_blocks[0], + (self.nr + self.num_threads_per_blocks[1] - 1) + // self.num_threads_per_blocks[1], + ) + + def _data_prep_dynamic( + self, + ns, + nr, + nt, + ni, + nz, + dt, + aperture, + angleaperture, + aperturetap, + vel, + six, + rix, + trav_recs, + angle_recs, + trav_srcs, + angle_srcs, + amp_srcs, + amp_recs, + ): + """Data preparation for dynamic computation + + Prepare data for dynamic computation by transferring some variables to device + memory in advance once. + """ ns_d = np.int32(ns) nr_d = np.int32(nr) nt_d = np.int32(nt) @@ -85,8 +117,22 @@ def _data_prep_dynamic(self, ns, nr, nt, ni, nz, dt, aperture, angleaperture, ap vel_d = cuda.to_device(vel) six_d = cuda.to_device(six) rix_d = cuda.to_device(rix) - self.const_inputs = (ns_d, nr_d, nt_d, ni_d, nz_d, dt_d, aperturemin_d, aperturemax_d, angleaperturemin_d, - angleaperturemax_d, vel_d, aperturetap_d, six_d, rix_d) + self.const_inputs = ( + ns_d, + nr_d, + nt_d, + ni_d, + nz_d, + dt_d, + aperturemin_d, + aperturemax_d, + angleaperturemin_d, + angleaperturemax_d, + vel_d, + aperturetap_d, + six_d, + rix_d, + ) self.trav_recs_d_global = cuda.to_device(trav_recs) self.angles_recs_d_global = cuda.to_device(angle_recs) @@ -95,41 +141,6 @@ def _data_prep_dynamic(self, ns, nr, nt, ni, nz, dt, aperture, angleaperture, ap self.amp_srcs_d_global = cuda.to_device(amp_srcs) self.amp_recs_d_global = cuda.to_device(amp_recs) - @staticmethod - @cuda.jit - def _trav_kirch_matvec_cuda(x, y, nsnr, nt, ni, itrav, travd): - isrcrec = cuda.grid(1) - if nsnr > isrcrec: - for ii in range(ni): - itravisrcrec = itrav[:, isrcrec] - travdisrcrec = travd[:, isrcrec] - itravii = itravisrcrec[ii] - travdii = travdisrcrec[ii] - if 0 <= itravii < nt - 1: - ind1 = (isrcrec, itravii) - val1 = x[ii] * (1 - travdii) - ind2 = (isrcrec, itravii + 1) - val2 
= x[ii] * travdii - cuda.atomic.add(y, ind1, val1) - cuda.atomic.add(y, ind2, val2) - - @staticmethod - @cuda.jit - def _trav_kirch_rmatvec_cuda(x, y, nsnr, nt, ni, itrav, travd): - isrcrec = cuda.grid(1) - if nsnr > isrcrec: - for ii in range(ni): - itravii = itrav[ii] - travdii = travd[ii] - itravisrcrecii = itravii[isrcrec] - travdisrcrecii = travdii[isrcrec] - if 0 <= itravisrcrecii < nt - 1: - vii = ( - x[isrcrec, itravisrcrecii] * (1 - travdisrcrecii) - + x[isrcrec, itravisrcrecii + 1] * travdisrcrecii - ) - cuda.atomic.add(y, ii, vii) - @staticmethod @cuda.jit def _travsrcrec_kirch_matvec_cuda(x, y, ns, nr, nt, ni, dt, trav_srcs, trav_recs): @@ -152,7 +163,7 @@ def _travsrcrec_kirch_matvec_cuda(x, y, ns, nr, nt, ni, dt, trav_srcs, trav_recs @staticmethod @cuda.jit def _travsrcrec_kirch_rmatvec_cuda(x, y, ns, nr, nt, ni, dt, trav_srcs, trav_recs): - irec, isrc = cuda.grid(2) + isrc, irec = cuda.grid(2) if ns > isrc and nr > irec: for ii in range(ni): trav_srcsii = trav_srcs[ii] @@ -164,20 +175,39 @@ def _travsrcrec_kirch_rmatvec_cuda(x, y, ns, nr, nt, ni, dt, trav_srcs, trav_rec travdii = travii / dt - itravii if 0 <= itravii < nt - 1: vii = ( - x[isrc * nr + irec, itravii] * (1 - travdii) - + x[isrc * nr + irec, itravii + 1] * travdii + x[isrc * nr + irec, itravii] * (1 - travdii) + + x[isrc * nr + irec, itravii + 1] * travdii ) cuda.atomic.add(y, ii, vii) @staticmethod @cuda.jit - def _ampsrcrec_kirch_matvec_cuda_streams(ns, nr, nt, ni, nz, dt, aperturemin, aperturemax, angleaperturemin, - angleaperturemax, - vel, aperturetap, - six_d, rix_d, - travsrc, ampsrc, anglesrc, - travi, ampi, anglei, - y, isrc_list, irec, x): + def _ampsrcrec_kirch_matvec_cuda_streams( + ns, + nr, + nt, + ni, + nz, + dt, + aperturemin, + aperturemax, + angleaperturemin, + angleaperturemax, + vel, + aperturetap, + six_d, + rix_d, + travsrc, + ampsrc, + anglesrc, + travi, + ampi, + anglei, + y, + isrc_list, + irec, + x, + ): ii = cuda.grid(1) if ni > ii: index_isrc = -1 @@ -204,36 
+234,35 @@ def _ampsrcrec_kirch_matvec_cuda_streams(ns, nr, nt, ni, nz, dt, aperturemin, ap angle_rec = angleirec[ii] abs_angle_src = abs(angle_src) abs_angle_rec = abs(angle_rec) - abs_angle_src_rec = abs(angle_src + angle_rec) aptap = 1.0 # angle apertures checks if ( - abs_angle_src < angleaperturemax - and abs_angle_rec < angleaperturemax + abs_angle_src < angleaperturemax + and abs_angle_rec < angleaperturemax ): if abs_angle_src >= angleaperturemin: # extract source angle aperture taper value aptap = ( - aptap - * aperturetap[ - int( - 20 - * (abs_angle_src - angleaperturemin) - // dangleaperture - ) - ] + aptap + * aperturetap[ + int( + 20 + * (abs_angle_src - angleaperturemin) + // dangleaperture + ) + ] ) if abs_angle_rec >= angleaperturemin: # extract receiver angle aperture taper value aptap = ( - aptap - * aperturetap[ - int( - 20 - * (abs_angle_rec - angleaperturemin) - // dangleaperture - ) - ] + aptap + * aperturetap[ + int( + 20 + * (abs_angle_rec - angleaperturemin) + // dangleaperture + ) + ] ) # identify x-index of image point iz = ii % nz @@ -243,12 +272,10 @@ def _ampsrcrec_kirch_matvec_cuda_streams(ns, nr, nt, ni, nz, dt, aperturemin, ap if aperture >= aperturemin: # extract aperture taper value aptap = ( - aptap - * aperturetap[ - int( - 20 * ((aperture - aperturemin) // daperture) - ) - ] + aptap + * aperturetap[ + int(20 * ((aperture - aperturemin) // daperture)) + ] ) # time limit check if 0 <= itravii < nt - 1: @@ -261,13 +288,32 @@ def _ampsrcrec_kirch_matvec_cuda_streams(ns, nr, nt, ni, nz, dt, aperturemin, ap @staticmethod @cuda.jit - def _ampsrcrec_kirch_rmatvec_cuda_streams(ns, nr, nt, ni, nz, dt, aperturemin, aperturemax, angleaperturemin, - angleaperturemax, - vel, aperturetap, - six_d, rix_d, - travsrc, ampsrc, anglesrc, - travi, ampi, anglei, - y, isrc_list, irec, x): + def _ampsrcrec_kirch_rmatvec_cuda_streams( + ns, + nr, + nt, + ni, + nz, + dt, + aperturemin, + aperturemax, + angleaperturemin, + angleaperturemax, + vel, + 
aperturetap, + six_d, + rix_d, + travsrc, + ampsrc, + anglesrc, + travi, + ampi, + anglei, + y, + isrc_list, + irec, + x, + ): ii = cuda.grid(1) if ni > ii: index_isrc = -1 @@ -291,39 +337,38 @@ def _ampsrcrec_kirch_rmatvec_cuda_streams(ns, nr, nt, ni, nz, dt, aperturemin, a angle_rec = angleirec[ii] abs_angle_src = abs(angle_src) abs_angle_rec = abs(angle_rec) - abs_angle_src_rec = abs(angle_src + angle_rec) aptap = 1.0 cosangle = cos((angle_src - angle_rec) / 2.0) damp = 2.0 * cosangle * ampisrc[ii] * ampirec[ii] / vel[ii] # angle apertures checks if ( - abs_angle_src < angleaperturemax - and abs_angle_rec < angleaperturemax + abs_angle_src < angleaperturemax + and abs_angle_rec < angleaperturemax ): if abs_angle_src >= angleaperturemin: # extract source angle aperture taper value aptap = ( - aptap - * aperturetap[ - int( - 20 - * (abs_angle_src - angleaperturemin) - // dangleaperture - ) - ] + aptap + * aperturetap[ + int( + 20 + * (abs_angle_src - angleaperturemin) + // dangleaperture + ) + ] ) if abs_angle_rec >= angleaperturemin: # extract receiver angle aperture taper value aptap = ( - aptap - * aperturetap[ - int( - 20 - * (abs_angle_rec - angleaperturemin) - // dangleaperture - ) - ] + aptap + * aperturetap[ + int( + 20 + * (abs_angle_rec - angleaperturemin) + // dangleaperture + ) + ] ) # identify x-index of image point iz = ii % nz @@ -333,23 +378,21 @@ def _ampsrcrec_kirch_rmatvec_cuda_streams(ns, nr, nt, ni, nz, dt, aperturemin, a if aperture >= aperturemin: # extract aperture taper value aptap = ( - aptap - * aperturetap[ - int( - 20 * ((aperture - aperturemin) // daperture) - ) - ] + aptap + * aperturetap[ + int(20 * ((aperture - aperturemin) // daperture)) + ] ) # time limit check if 0 <= itravii < nt - 1: ind1 = ii val1 = ( - ( - x[isrc * nr + irec, itravii] * (1 - travdii) - + x[isrc * nr + irec, itravii + 1] * travdii - ) - * damp - * aptap + ( + x[isrc * nr + irec, itravii] * (1 - travdii) + + x[isrc * nr + irec, itravii + 1] * travdii + ) + * 
damp + * aptap ) cuda.atomic.add(y, ind1, val1) @@ -357,19 +400,20 @@ def _process_streams(self, x, opt): """Process data using CUDA streams for dynamic computation. This method handles data preparation and execution of CUDA kernels using streams, - for both forward ('_matvec') and adjoint ('_rmatvec') operations. + for both forward and adjoint operations. Parameters ---------- - x : ndarray + x : :obj:`numpy.ndarray` Input data (image or seismic data). - opt : str + opt : :obj:`str` Operation type, either '_matvec' for forward or '_rmatvec' for adjoint. Returns ------- - y : ndarray + y : :obj:`numpy.ndarray` Output data after processing. + """ if opt == "_matvec": @@ -390,7 +434,9 @@ def _process_streams(self, x, opt): for irec in range(self.nr): for stream, isrc_list in self.sources_per_streams.items(): if opt == "_matvec": - self._ampsrcrec_kirch_matvec_cuda_streams[self.num_blocks, self.num_threads_per_blocks, stream]( + self._ampsrcrec_kirch_matvec_cuda_streams[ + self.num_blocks, self.num_threads_per_blocks, stream + ]( *self.const_inputs, self.trav_srcs_d_global, self.amp_srcs_d_global, @@ -399,9 +445,14 @@ def _process_streams(self, x, opt): self.amp_recs_d_global, self.angles_recs_d_global, y_d_dict[stream], - isrc_list_d_dict[stream], irec, x_d) + isrc_list_d_dict[stream], + irec, + x_d + ) elif opt == "_rmatvec": - self._ampsrcrec_kirch_rmatvec_cuda_streams[self.num_blocks, self.num_threads_per_blocks, stream]( + self._ampsrcrec_kirch_rmatvec_cuda_streams[ + self.num_blocks, self.num_threads_per_blocks, stream + ]( *self.const_inputs, self.trav_srcs_d_global, self.amp_srcs_d_global, @@ -410,110 +461,85 @@ def _process_streams(self, x, opt): self.amp_recs_d_global, self.angles_recs_d_global, y_d_dict[stream], - isrc_list_d_dict[stream], irec, x_d) + isrc_list_d_dict[stream], + irec, + x_d + ) # Synchronize the streams to ensure all operations have been completed for stream in self.sources_per_streams.keys(): stream.synchronize() y_streams = [] - # for 
idx, stream in enumerate(self.streams): for stream, y_dev in y_d_dict.items(): - # print("synchronize") y_streams.append(y_dev.copy_to_host(stream=stream)) - # print("Done synchronize") - # print("Done Done synchronize") y_total = np.sum(y_streams, axis=0) return y_total def _matvec_call(self, *inputs): """Handle the forward operation call, dispatching to appropriate CUDA kernels. - This method selects the appropriate kernel to execute based on the computation flags, - and performs the forward operation (matrix-vector multiplication). + This method selects the appropriate kernel to execute based on the computation flags, + and performs the forward operation. + + Parameters + ---------- + *inputs : :obj:`list` + List of input parameters required by the kernels. - Parameters - ---------- - *inputs : list - List of input parameters required by the kernels. + Returns + ------- + y : :obj:`numpy.ndarray` + Output data (seismic data) of forward operation. - Returns - ------- - y : ndarray - Output data (seismic data) after forward operation. 
""" - if self.dynamic and self.travsrcrec: - y = self._process_streams(inputs[0], "_matvec") - elif self.travsrcrec: # len(inputs) == 9 - x_d = cuda.to_device(inputs[0]) - y_d = cuda.to_device(inputs[1]) + if self.dynamic: + y_d = self._process_streams(inputs[0], "_matvec") + else: + x_d = inputs[0] + y_d = inputs[1] ns_d = np.int32(inputs[2]) nr_d = np.int32(inputs[3]) nt_d = np.int32(inputs[4]) ni_d = np.int32(inputs[5]) dt_d = np.float32(inputs[6]) - trav_srcs_d = cuda.to_device(inputs[7]) - trav_recs_d = cuda.to_device(inputs[8]) - self._travsrcrec_kirch_matvec_cuda[self.num_blocks, self.num_threads_per_blocks](x_d, y_d, ns_d, nr_d, nt_d, - ni_d, dt_d, - trav_srcs_d, trav_recs_d) - elif not self.travsrcrec: # len(inputs) == 7: - x_d = cuda.to_device(inputs[0]) - y_d = cuda.to_device(inputs[1]) - nsnr_d = np.int32(inputs[2]) - nt_d = np.int32(inputs[3]) - ni_d = np.int32(inputs[4]) - itrav_d = cuda.to_device(inputs[5]) - travd_d = cuda.to_device(inputs[6]) - self._trav_kirch_matvec_cuda[self.num_blocks, self.num_threads_per_blocks](x_d, y_d, nsnr_d, nt_d, ni_d, - itrav_d, travd_d) - - if not self.dynamic: + trav_srcs_d = to_cupy(inputs[7]) + trav_recs_d = to_cupy(inputs[8]) + self._travsrcrec_kirch_matvec_cuda[ + self.num_blocks, self.num_threads_per_blocks + ](x_d, y_d, ns_d, nr_d, nt_d, ni_d, dt_d, trav_srcs_d, trav_recs_d) cuda.synchronize() - y = y_d.copy_to_host() - return y + return y_d def _rmatvec_call(self, *inputs): - """ Handle the adjoint operation call, dispatching to appropriate CUDA kernels. + """Handle the adjoint operation call, dispatching to appropriate CUDA kernels. + + This method selects the appropriate kernel to execute based on the computation flags, + and performs the adjoint operation. - This method selects the appropriate kernel to execute based on the computation flags, - and performs the adjoint operation (matrix-vector multiplication with the transpose). 
+ Parameters + ---------- + *inputs : :obj:`list` + List of input parameters required by the kernels. - Parameters - ---------- - *inputs : list - List of input parameters required by the kernels. + Returns + ------- + y : :obj:`numpy.ndarray` + Output data (image) of adjoint operation. - Returns - ------- - y : ndarray - Output data (image) after adjoint operation. """ - if self.dynamic and self.travsrcrec: - y = self._process_streams(inputs[0], "_rmatvec") - elif self.travsrcrec: # len(inputs) == 9 - x_d = cuda.to_device(inputs[0]) - y_d = cuda.to_device(inputs[1]) + if self.dynamic: + y_d = self._process_streams(inputs[0], "_rmatvec") + else: + x_d = inputs[0] + y_d = inputs[1] ns_d = np.int32(inputs[2]) nr_d = np.int32(inputs[3]) nt_d = np.int32(inputs[4]) ni_d = np.int32(inputs[5]) dt_d = np.float32(inputs[6]) - trav_srcs_d = cuda.to_device(inputs[7]) - trav_recs_d = cuda.to_device(inputs[8]) - self._travsrcrec_kirch_rmatvec_cuda[self.num_blocks, self.num_threads_per_blocks](x_d, y_d, ns_d, nr_d, - nt_d, ni_d, dt_d, - trav_srcs_d, trav_recs_d) - elif not self.travsrcrec: # len(inputs) == 7: - x_d = cuda.to_device(inputs[0]) - y_d = cuda.to_device(inputs[1]) - nsnr_d = np.int32(inputs[2]) - nt_d = np.int32(inputs[3]) - ni_d = np.int32(inputs[4]) - itrav_d = cuda.to_device(inputs[5]) - travd_d = cuda.to_device(inputs[6]) - self._trav_kirch_rmatvec_cuda[self.num_blocks, self.num_threads_per_blocks](x_d, y_d, nsnr_d, nt_d, ni_d, - itrav_d, travd_d) - - if not self.dynamic: + trav_srcs_d = to_cupy(inputs[7]) + trav_recs_d = to_cupy(inputs[8]) + self._travsrcrec_kirch_rmatvec_cuda[ + self.num_blocks, self.num_threads_per_blocks + ](x_d, y_d, ns_d, nr_d, nt_d, ni_d, dt_d, trav_srcs_d, trav_recs_d) cuda.synchronize() - y = y_d.copy_to_host() - return y + return y_d diff --git a/pylops/waveeqprocessing/kirchhoff.py b/pylops/waveeqprocessing/kirchhoff.py index 4ba7b939..43727313 100644 --- a/pylops/waveeqprocessing/kirchhoff.py +++ b/pylops/waveeqprocessing/kirchhoff.py 
@@ -12,6 +12,7 @@ from pylops.signalprocessing import Convolve1D from pylops.utils import deps from pylops.utils._internal import _value_or_sized_to_array +from pylops.utils.backend import get_array_module from pylops.utils.decorators import reshaped from pylops.utils.tapers import taper from pylops.utils.typing import DTypeLike, NDArray @@ -25,7 +26,8 @@ if jit_message is None: from numba import jit, prange - from ._kirchhoff_cuda import _kirchhoffCudaHelper + from ._kirchhoff_cuda import _KirchhoffCudaHelper + # detect whether to use parallel or not numba_threads = int(os.getenv("NUMBA_NUM_THREADS", "1")) parallel = True if numba_threads != 1 else False @@ -83,8 +85,8 @@ class Kirchhoff(LinearOperator): :math:`\lbrack (n_y) n_x n_z \times n_s n_r \rbrack` or pair of traveltime tables of size :math:`\lbrack (n_y) n_x n_z \times n_s \rbrack` and :math:`\lbrack (n_y) n_x n_z \times n_r \rbrack` (to be provided if ``mode='byot'``). Note that the latter approach is recommended as less memory demanding - than the former. Moreover, only ``mode='dynamic'`` is only possible when traveltimes are provided in - the latter form. + than the former. Moreover, ``mode='dynamic'`` and ``engine='cuda'`` are only possible when traveltimes are + provided in the latter form. amp : :obj:`numpy.ndarray`, optional .. versionadded:: 2.0.0 @@ -109,7 +111,7 @@ class Kirchhoff(LinearOperator): Deprecated, will be removed in v3.0.0. Simply kept for back-compatibility with previous implementation, but effectively not affecting the behaviour of the operator. engine : :obj:`str`, optional - Engine used for computations (``numpy`` or ``numba``). + Engine used for computations (``numpy``, ``numba`` or ``cuda``). dtype : :obj:`str`, optional Type of elements in input array. 
name : :obj:`str`, optional @@ -999,7 +1001,6 @@ def _register_multiplications(self, engine: str) -> None: if engine not in ["numpy", "numba", "cuda"]: raise KeyError("engine must be numpy or numba or cuda") if engine == "numba" and jit_message is None: - # numba numba_opts = dict( nopython=True, nogil=True, parallel=parallel ) # fastmath=True, @@ -1014,16 +1015,33 @@ def _register_multiplications(self, engine: str) -> None: self._kirch_rmatvec = jit(**numba_opts)(self._trav_kirch_rmatvec) elif engine == "cuda": if self.dynamic and self.travsrcrec: - self.cuda_helper = _kirchhoffCudaHelper(self.ns, self.nr, self.nt, self.ni, 1, 1) - self.cuda_helper._data_prep_dynamic(self.ns, self.nr, self.nt, self.ni, self.nz, self.dt, - self.aperture, self.angleaperture, - self.aperturetap, self.vel, self.six, self.rix, self.trav_recs, - self.angle_recs, self.trav_srcs, self.angle_srcs,self.amp_srcs, - self.amp_recs) + self.cuda_helper = _KirchhoffCudaHelper( + self.ns, self.nr, self.nt, self.ni, 1 + ) + self.cuda_helper._data_prep_dynamic( + self.ns, + self.nr, + self.nt, + self.ni, + self.nz, + self.dt, + self.aperture, + self.angleaperture, + self.aperturetap, + self.vel, + self.six, + self.rix, + self.trav_recs, + self.angle_recs, + self.trav_srcs, + self.angle_srcs, + self.amp_srcs, + self.amp_recs, + ) elif self.travsrcrec: - self.cuda_helper = _kirchhoffCudaHelper(self.ns, self.nr, self.nt, self.ni, 0, 1) - elif not self.travsrcrec: - self.cuda_helper = _kirchhoffCudaHelper(self.ns, self.nr, self.nt, self.ni, 0, 0) + self.cuda_helper = _KirchhoffCudaHelper( + self.ns, self.nr, self.nt, self.ni, 0 + ) self._kirch_matvec = self.cuda_helper._matvec_call self._kirch_rmatvec = self.cuda_helper._rmatvec_call else: @@ -1041,7 +1059,8 @@ def _register_multiplications(self, engine: str) -> None: @reshaped def _matvec(self, x: NDArray) -> NDArray: - y = np.zeros((self.nsnr, self.nt), dtype=self.dtype) + ncp = get_array_module(x) + y = ncp.zeros((self.nsnr, self.nt), 
dtype=self.dtype) if self.dynamic and self.travsrcrec: inputs = ( x.ravel(), @@ -1088,9 +1107,10 @@ def _matvec(self, x: NDArray) -> NDArray: @reshaped def _rmatvec(self, x: NDArray) -> NDArray: + ncp = get_array_module(x) x = self.cop._rmatvec(x.ravel()) x = x.reshape(self.nsnr, self.nt) - y = np.zeros(self.ni, dtype=self.dtype) + y = ncp.zeros(self.ni, dtype=self.dtype) if self.dynamic and self.travsrcrec: inputs = ( x,