diff --git a/caput/mpiarray.py b/caput/mpiarray.py
index c6264bc1..f74c5d04 100644
--- a/caput/mpiarray.py
+++ b/caput/mpiarray.py
@@ -734,6 +734,9 @@ def redistribute(self, axis: int) -> "MPIArray":
         if self.axis == axis or self.comm is None:
             return self
 
+        if self.comm.size == 1:
+            return MPIArray.wrap(self.local_array, axis, self.comm)
+
         # Check to make sure there is enough memory to perform the redistribution.
         # Must be able to allocate the target array and 2 buffers. We allocate
         # slightly more space than needed to be safe
@@ -756,72 +759,62 @@ def redistribute(self, axis: int) -> "MPIArray":
         csize = self.comm.size
         crank = self.comm.rank
 
-        if csize == 1:
-            if arr.shape[self.axis] == self.global_shape[self.axis]:
-                # We are working on a single node.
-                target_arr[:] = arr
-            else:
-                raise ValueError(
-                    f"Global shape {self.global_shape} is incompatible with local "
-                    f"array shape {self.shape}"
-                )
-        else:
-            # Get the start and end of each subrange of interest
-            _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
-            _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
-            # Split the soruce array into properly sized blocks for sending
-            blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
-            # Create fixed-size contiguous buffers for sending and receiving
-            buffer_shape = list(target_arr.shape)
-            buffer_shape[self.axis] = max(ear - sar)
-            buffer_shape[axis] = max(eac - sac)
-            # Pre-allocate buffers and buffer type
-            recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
-            send_buffer = np.empty_like(recv_buffer)
-            buf_type = self._prep_buf(send_buffer)[1]
-
-            # Empty slices for target, send buf, recv buf
-            targetsl = [slice(None)] * len(buffer_shape)
-            sendsl = [slice(None)] * len(buffer_shape)
-            recvsl = [slice(None)] * len(buffer_shape)
-            # Send and recv buf have some fixed axis slices per rank
-            sendsl[self.axis] = slice(ear[crank] - sar[crank])
-            recvsl[axis] = slice(eac[crank] - sac[crank])
-
-            mpistatus = mpiutil.MPI.Status()
-
-            # Cyclically pass messages forward to i adjacent rank
-            for i in range(csize):
-                send_to = (crank + i) % csize
-                recv_from = (crank - i) % csize
-
-                # Write send data into send buffer location
-                sendsl[axis] = slice(eac[send_to] - sac[send_to])
-                send_buffer[tuple(sendsl)] = blocks[send_to]
-
-                self.comm.Sendrecv(
-                    sendbuf=(send_buffer, buf_type),
-                    dest=send_to,
-                    sendtag=(csize * crank + send_to),
-                    recvbuf=(recv_buffer, buf_type),
-                    source=recv_from,
-                    recvtag=(csize * recv_from + crank),
-                    status=mpistatus,
-                )
+        # Get the start and end of each subrange of interest
+        _, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
+        _, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
+        # Split the source array into properly sized blocks for sending
+        blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
+        # Create fixed-size contiguous buffers for sending and receiving
+        buffer_shape = list(target_arr.shape)
+        buffer_shape[self.axis] = max(ear - sar)
+        buffer_shape[axis] = max(eac - sac)
+        # Pre-allocate buffers and buffer type
+        recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
+        send_buffer = np.empty_like(recv_buffer)
+        buf_type = self._prep_buf(send_buffer)[1]
+
+        # Empty slices for target, send buf, recv buf
+        targetsl = [slice(None)] * len(buffer_shape)
+        sendsl = [slice(None)] * len(buffer_shape)
+        recvsl = [slice(None)] * len(buffer_shape)
+        # Send and recv buf have some fixed axis slices per rank
+        sendsl[self.axis] = slice(ear[crank] - sar[crank])
+        recvsl[axis] = slice(eac[crank] - sac[crank])
+
+        mpistatus = mpiutil.MPI.Status()
+
+        # Cyclically pass and receive array chunks across ranks
+        for i in range(csize):
+            send_to = (crank + i) % csize
+            recv_from = (crank - i) % csize
+
+            # Write send data into send buffer location
+            sendsl[axis] = slice(eac[send_to] - sac[send_to])
+            send_buffer[tuple(sendsl)] = blocks[send_to]
+
+            self.comm.Sendrecv(
+                sendbuf=(send_buffer, buf_type),
+                dest=send_to,
+                sendtag=(csize * crank + send_to),
+                recvbuf=(recv_buffer, buf_type),
+                source=recv_from,
+                recvtag=(csize * recv_from + crank),
+                status=mpistatus,
+            )
 
-                if mpistatus.error != mpiutil.MPI.SUCCESS:
-                    logger.error(
-                        f"**** ERROR in MPI SEND/RECV "
-                        f"(rank={crank}, "
-                        f"target={send_to}, "
-                        f"receive={recv_from}) ****"
-                    )
+            if mpistatus.error != mpiutil.MPI.SUCCESS:
+                logger.error(
+                    f"**** ERROR in MPI SEND/RECV "
+                    f"(rank={crank}, "
+                    f"target={send_to}, "
+                    f"receive={recv_from}) ****"
+                )
 
-                # Write buffer into target location
-                targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
-                recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
+            # Write buffer into target location
+            targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
+            recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
 
-                target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
+            target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
 
         return dist_arr
 
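For context, the loop kept (and de-indented) by this patch is a cyclic pairwise exchange: on step i each rank sends one block to rank (rank + i) % size and receives one from rank (rank - i) % size, so only two fixed-size buffers are live at a time instead of staging the whole exchange at once. The sketch below is illustrative only and is not part of the patch or of the caput API: it redistributes a plain 2D NumPy array from a row split to a column split using mpi4py directly, and the helper names split_bounds and cyclic_redistribute are hypothetical.

    # Illustrative sketch only -- not part of the patch. Assumes numpy and mpi4py;
    # `split_bounds` and `cyclic_redistribute` are hypothetical names, not caput API.
    import numpy as np
    from mpi4py import MPI


    def split_bounds(n, size):
        # Start/end indices when n items are split as evenly as possible over `size` ranks
        counts = np.full(size, n // size)
        counts[: n % size] += 1
        ends = np.cumsum(counts)
        return ends - counts, ends


    def cyclic_redistribute(local, n_rows, n_cols, comm):
        # Move a row-distributed 2D array (local = my rows, all columns) to a
        # column-distributed layout (all rows, my columns) via pairwise Sendrecv.
        size, rank = comm.size, comm.rank
        row_s, row_e = split_bounds(n_rows, size)  # current distribution (axis 0)
        col_s, col_e = split_bounds(n_cols, size)  # target distribution (axis 1)

        target = np.empty((n_rows, col_e[rank] - col_s[rank]), dtype=local.dtype)

        for i in range(size):
            send_to = (rank + i) % size
            recv_from = (rank - i) % size

            # My rows, restricted to the columns owned by `send_to` after redistribution
            send_block = np.ascontiguousarray(local[:, col_s[send_to] : col_e[send_to]])
            # `recv_from` sends its rows restricted to my columns, so the shape is known
            recv_block = np.empty(
                (row_e[recv_from] - row_s[recv_from], col_e[rank] - col_s[rank]),
                dtype=local.dtype,
            )
            comm.Sendrecv(send_block, dest=send_to, recvbuf=recv_block, source=recv_from)

            target[row_s[recv_from] : row_e[recv_from]] = recv_block

        return target


    if __name__ == "__main__":
        comm = MPI.COMM_WORLD
        n_rows, n_cols = 8, 6
        rs, re = split_bounds(n_rows, comm.size)
        # Each rank fills its own rows of a known global pattern, then redistributes
        local = np.arange(n_rows * n_cols).reshape(n_rows, n_cols)[rs[comm.rank] : re[comm.rank]]
        by_cols = cyclic_redistribute(local, n_rows, n_cols, comm)

Run under mpirun with any rank count; on step i the block a rank receives has exactly the shape its partner sends, and the combined send/receive in Sendrecv avoids deadlock without any additional buffering, which is the same property the pre-allocated send/recv buffers above rely on.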