Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallelize the LSS pipeline #61

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions cora/core/skysim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import scipy.integrate as si
import healpy

from caput import mpiarray
from cora.util import hputil, nputil


Expand Down Expand Up @@ -92,30 +93,45 @@ def mkfullsky(corr, nside, alms=False, rng=None):
"""

numz = corr.shape[1]
maxl = corr.shape[0] - 1

isdistributed = isinstance(corr, mpiarray.MPIArray)

if isdistributed:
maxl = corr.global_shape[0] - 1
corr = corr.local_array
else:
maxl = corr.shape[0] - 1

if corr.shape[2] != numz:
raise Exception("Correlation matrix is incorrect shape.")

alm_array = np.zeros((numz, 1, maxl + 1, maxl + 1), dtype=np.complex128)
alm_array = mpiarray.zeros(
(numz, 1, maxl + 1, maxl + 1), dtype=np.complex128, axis=2
)

# Generate gaussian deviates and transform to have correct correlation
# structure
for l in range(maxl + 1):
for lloc, lglob in alm_array.enumerate(axis=2):
# Add in a small diagonal to try and ensure positive definiteness
cmax = corr[l].diagonal().max() * 1e-14
corrm = corr[l] + np.identity(numz) * cmax
cmax = corr[lloc].diagonal().max() * 1e-14
corrm = corr[lloc] + np.identity(numz) * cmax

trans = nputil.matrix_root_manynull(corrm, truncate=False)
gaussvars = nputil.complex_std_normal((numz, l + 1), rng=rng)
alm_array[:, 0, l, : (l + 1)] = np.dot(trans, gaussvars)
gaussvars = nputil.complex_std_normal((numz, lglob + 1), rng=rng)
alm_array.local_array[:, 0, lloc, : (lglob + 1)] = np.dot(trans, gaussvars)

if alms:
return alm_array
# Return the entire alm array on each rank
return alm_array.allgather()

# Perform the spherical harmonic transform for each z
sky = hputil.sphtrans_inv_sky(alm_array, nside)
sky = sky[:, 0]
alm_array = alm_array.redistribute(axis=0)

sky = hputil.sphtrans_inv_sky(alm_array.local_array, nside)[:, 0]

if isdistributed:
# Re-wrap the final array
sky = mpiarray.MPIArray.wrap(sky, axis=0)

return sky

Expand Down
36 changes: 22 additions & 14 deletions cora/signal/corrfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import scipy.special as ss
from scipy.fftpack import dct

from caput import mpiarray

import hankl
import hankel
import pyfftlog
Expand Down Expand Up @@ -329,9 +331,6 @@ def corr_to_clarray(
M = q * lmax
mu, w, wsum = ss.roots_legendre(M, mu=True)

xlen = xarray.size
corr_array = np.zeros((M, xlen, xlen))

# If xromb > 0 we need to integrate over the radial bin width, start by modifying
# the array of distances to add extra points over which we'll integrate
#
Expand Down Expand Up @@ -359,12 +358,15 @@ def corr_to_clarray(
else:
xa = xarray

# Split the set of theta values to fill into groups of length ~50 and process each,
# this is helps reduce memory usage which otherwise we be massively inflated by the
# extra points we integrate over
for msec in np.array_split(np.arange(M), M // 50):

rc = coord.cosine_rule(mu[msec], xa, xa)
xlen = xarray.size
corr_array = mpiarray.zeros((M, xlen, xlen), axis=0)
clo = corr_array.local_offset[0]
_len = corr_array.local_array.shape[0]

# Split thetas into ~length 50 chunks, otherwise memory will blow up
for msec in np.array_split(np.arange(_len), _len // 50):
# Index into the global index in mu
rc = coord.cosine_rule(mu[clo + msec], xa, xa)
corr1 = corr(rc)

# If xromb then we need to integrate over the redshift bins which we do using
Expand All @@ -374,15 +376,21 @@ def corr_to_clarray(
corr1 = np.matmul(corr1, x_w).reshape(-1, xlen, xint, xlen)
corr1 = np.matmul(corr1.transpose(0, 1, 3, 2), x_w)

corr_array[msec, :, :] = corr1
corr_array.local_array[msec] = corr1

# Perform the dot product split over ranks for
# memory and time
lm = legendre_array(lmax, mu)
lm *= w * 4.0 * np.pi / wsum
lm *= w[np.newaxis] * 4.0 * np.pi / wsum

# Reshape and properly distribute the array to
# perform the dot product
corr_array = corr_array.reshape(None, -1).redistribute(axis=1)

clxx = np.dot(lm, corr_array.reshape(M, -1))
clxx = clxx.reshape(lmax + 1, xlen, xlen)
clxx = np.dot(lm, corr_array.local_array)
clxx = mpiarray.MPIArray.wrap(clxx, axis=1).redistribute(axis=0)

return clxx
return clxx.reshape(None, xlen, xlen)


def ps_to_aps_flat(
Expand Down
Loading