Skip to content

Commit

Permalink
perf(mpiarray): avoid redistribute copy on single process
Browse files Browse the repository at this point in the history
  • Loading branch information
ljgray committed Oct 30, 2023
1 parent ea85ecb commit a2f8ba9
Showing 1 changed file with 58 additions and 63 deletions.
121 changes: 58 additions & 63 deletions caput/mpiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,10 @@ def redistribute(self, axis: int) -> "MPIArray":
if self.axis == axis or self.comm is None:
return self

if self.comm.size == 1:
self._axis = axis
return self

# Check to make sure there is enough memory to perform the redistribution.
# Must be able to allocate the target array and 2 buffers. We allocate
# slightly more space than needed to be safe
Expand All @@ -757,76 +761,67 @@ def redistribute(self, axis: int) -> "MPIArray":
# Get views into local and target arrays
arr = self.local_array
target_arr = dist_arr.local_array

# Avoid repeat mpi property calls
csize = self.comm.size
crank = self.comm.rank

if csize == 1:
if arr.shape[self.axis] == self.global_shape[self.axis]:
# We are working on a single node.
target_arr[:] = arr
else:
raise ValueError(
f"Global shape {self.global_shape} is incompatible with local "
f"array shape {self.shape}"
)
else:
# Get the start and end of each subrange of interest
_, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
_, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
# Split the soruce array into properly sized blocks for sending
blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
# Create fixed-size contiguous buffers for sending and receiving
buffer_shape = list(target_arr.shape)
buffer_shape[self.axis] = max(ear - sar)
buffer_shape[axis] = max(eac - sac)
# Pre-allocate buffers and buffer type
recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
send_buffer = np.empty_like(recv_buffer)
buf_type = self._prep_buf(send_buffer)[1]

# Empty slices for target, send buf, recv buf
targetsl = [slice(None)] * len(buffer_shape)
sendsl = [slice(None)] * len(buffer_shape)
recvsl = [slice(None)] * len(buffer_shape)
# Send and recv buf have some fixed axis slices per rank
sendsl[self.axis] = slice(ear[crank] - sar[crank])
recvsl[axis] = slice(eac[crank] - sac[crank])

mpistatus = mpiutil.MPI.Status()

# Cyclically pass messages forward to i adjacent rank
for i in range(csize):
send_to = (crank + i) % csize
recv_from = (crank - i) % csize

# Write send data into send buffer location
sendsl[axis] = slice(eac[send_to] - sac[send_to])
send_buffer[tuple(sendsl)] = blocks[send_to]

self.comm.Sendrecv(
sendbuf=(send_buffer, buf_type),
dest=send_to,
sendtag=(csize * crank + send_to),
recvbuf=(recv_buffer, buf_type),
source=recv_from,
recvtag=(csize * recv_from + crank),
status=mpistatus,
)
# Get the start and end of each subrange of interest
_, sac, eac = mpiutil.split_all(self.global_shape[axis], self.comm)
_, sar, ear = mpiutil.split_all(self.global_shape[self.axis], self.comm)
# Split the soruce array into properly sized blocks for sending
blocks = np.array_split(arr, np.insert(eac, 0, sac[0]), axis)[1:]
# Create fixed-size contiguous buffers for sending and receiving
buffer_shape = list(target_arr.shape)
buffer_shape[self.axis] = max(ear - sar)
buffer_shape[axis] = max(eac - sac)
# Pre-allocate buffers and buffer type
recv_buffer = np.empty(buffer_shape, dtype=self.dtype)
send_buffer = np.empty_like(recv_buffer)
buf_type = self._prep_buf(send_buffer)[1]

# Empty slices for target, send buf, recv buf
targetsl = [slice(None)] * len(buffer_shape)
sendsl = [slice(None)] * len(buffer_shape)
recvsl = [slice(None)] * len(buffer_shape)
# Send and recv buf have some fixed axis slices per rank
sendsl[self.axis] = slice(ear[crank] - sar[crank])
recvsl[axis] = slice(eac[crank] - sac[crank])

mpistatus = mpiutil.MPI.Status()

# Cyclically pass messages forward to i adjacent rank
for i in range(csize):
send_to = (crank + i) % csize
recv_from = (crank - i) % csize

# Write send data into send buffer location
sendsl[axis] = slice(eac[send_to] - sac[send_to])
send_buffer[tuple(sendsl)] = blocks[send_to]

self.comm.Sendrecv(
sendbuf=(send_buffer, buf_type),
dest=send_to,
sendtag=(csize * crank + send_to),
recvbuf=(recv_buffer, buf_type),
source=recv_from,
recvtag=(csize * recv_from + crank),
status=mpistatus,
)

if mpistatus.error != mpiutil.MPI.SUCCESS:
logger.error(
f"**** ERROR in MPI SEND/RECV "
f"(rank={crank}, "
f"target={send_to}, "
f"receive={recv_from}) ****"
)
if mpistatus.error != mpiutil.MPI.SUCCESS:
logger.error(
f"**** ERROR in MPI SEND/RECV "
f"(rank={crank}, "
f"target={send_to}, "
f"receive={recv_from}) ****"
)

# Write buffer into target location
targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])
# Write buffer into target location
targetsl[self.axis] = slice(sar[recv_from], ear[recv_from])
recvsl[self.axis] = slice(ear[recv_from] - sar[recv_from])

target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]
target_arr[tuple(targetsl)] = recv_buffer[tuple(recvsl)]

return dist_arr

Expand Down

0 comments on commit a2f8ba9

Please sign in to comment.