Skip to content

Commit

Permalink
feat(lss): make lss tasks mpi parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
ljgray committed Sep 26, 2024
1 parent db38363 commit b735fa9
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 93 deletions.
36 changes: 26 additions & 10 deletions cora/core/skysim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import scipy.integrate as si
import healpy

from caput import mpiarray
from cora.util import hputil, nputil


Expand Down Expand Up @@ -92,30 +93,45 @@ def mkfullsky(corr, nside, alms=False, rng=None):
"""

numz = corr.shape[1]
maxl = corr.shape[0] - 1

isdistributed = isinstance(corr, mpiarray.MPIArray)

if isdistributed:
maxl = corr.global_shape[0] - 1
corr = corr.local_array
else:
maxl = corr.shape[0] - 1

if corr.shape[2] != numz:
raise Exception("Correlation matrix is incorrect shape.")

alm_array = np.zeros((numz, 1, maxl + 1, maxl + 1), dtype=np.complex128)
alm_array = mpiarray.zeros(
(numz, 1, maxl + 1, maxl + 1), dtype=np.complex128, axis=2
)

# Generate gaussian deviates and transform to have correct correlation
# structure
for l in range(maxl + 1):
for lloc, lglob in alm_array.enumerate(axis=2):
# Add in a small diagonal to try and ensure positive definiteness
cmax = corr[l].diagonal().max() * 1e-14
corrm = corr[l] + np.identity(numz) * cmax
cmax = corr[lloc].diagonal().max() * 1e-14
corrm = corr[lloc] + np.identity(numz) * cmax

trans = nputil.matrix_root_manynull(corrm, truncate=False)
gaussvars = nputil.complex_std_normal((numz, l + 1), rng=rng)
alm_array[:, 0, l, : (l + 1)] = np.dot(trans, gaussvars)
gaussvars = nputil.complex_std_normal((numz, lglob + 1), rng=rng)
alm_array.local_array[:, 0, lloc, : (lglob + 1)] = np.dot(trans, gaussvars)

if alms:
return alm_array
# Return the entire alm array on each rank
return alm_array.allgather()

# Perform the spherical harmonic transform for each z
sky = hputil.sphtrans_inv_sky(alm_array, nside)
sky = sky[:, 0]
alm_array = alm_array.redistribute(axis=0)

sky = hputil.sphtrans_inv_sky(alm_array.local_array, nside)[:, 0]

if isdistributed:
# Re-wrap the final array
sky = mpiarray.MPIArray.wrap(sky, axis=0)

return sky

Expand Down
36 changes: 22 additions & 14 deletions cora/signal/corrfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import scipy.special as ss
from scipy.fftpack import dct

from caput import mpiarray

import hankl
import hankel
import pyfftlog
Expand Down Expand Up @@ -329,9 +331,6 @@ def corr_to_clarray(
M = q * lmax
mu, w, wsum = ss.roots_legendre(M, mu=True)

xlen = xarray.size
corr_array = np.zeros((M, xlen, xlen))

# If xromb > 0 we need to integrate over the radial bin width, start by modifying
# the array of distances to add extra points over which we'll integrate
#
Expand Down Expand Up @@ -359,12 +358,15 @@ def corr_to_clarray(
else:
xa = xarray

# Split the set of theta values to fill into groups of length ~50 and process each,
# this is helps reduce memory usage which otherwise we be massively inflated by the
# extra points we integrate over
for msec in np.array_split(np.arange(M), M // 50):

rc = coord.cosine_rule(mu[msec], xa, xa)
xlen = xarray.size
corr_array = mpiarray.zeros((M, xlen, xlen), axis=0)
clo = corr_array.local_offset[0]
_len = corr_array.local_array.shape[0]

# Split thetas into ~length 50 chunks, otherwise memory will blow up
for msec in np.array_split(np.arange(_len), _len // 50):
# Index into the global index in mu
rc = coord.cosine_rule(mu[clo + msec], xa, xa)
corr1 = corr(rc)

# If xromb then we need to integrate over the redshift bins which we do using
Expand All @@ -374,15 +376,21 @@ def corr_to_clarray(
corr1 = np.matmul(corr1, x_w).reshape(-1, xlen, xint, xlen)
corr1 = np.matmul(corr1.transpose(0, 1, 3, 2), x_w)

corr_array[msec, :, :] = corr1
corr_array.local_array[msec] = corr1

# Perform the dot product split over ranks for
# memory and time
lm = legendre_array(lmax, mu)
lm *= w * 4.0 * np.pi / wsum
lm *= w[np.newaxis] * 4.0 * np.pi / wsum

# Reshape and properly distribute the array to
# perform the dot product
corr_array = corr_array.reshape(None, -1).redistribute(axis=1)

clxx = np.dot(lm, corr_array.reshape(M, -1))
clxx = clxx.reshape(lmax + 1, xlen, xlen)
clxx = np.dot(lm, corr_array.local_array)
clxx = mpiarray.MPIArray.wrap(clxx, axis=1).redistribute(axis=0)

return clxx
return clxx.reshape(None, xlen, xlen)


def ps_to_aps_flat(
Expand Down
Loading

0 comments on commit b735fa9

Please sign in to comment.