Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Njit parallelization #16

Merged
merged 14 commits into from
Nov 9, 2023
70 changes: 39 additions & 31 deletions fdmt/cpu_fdmt.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
#!/usr/bin/env python3

import sys
import numpy as np
from time import time
from attr import attrs, attrib, cmp_using
#from . import fdmt_iter_par
from .fdmt_njit import fdmt_iter_par

import numpy as np
from attr import attrib, attrs, cmp_using

from fdmt.fdmt_njit import fdmt_iter_par


@attrs
class FDMT:
'''
Collection of attributes and helper arrays necessary for dispersion using
FDMT.
"""
Collection of attributes and helper arrays necessary for dispersion using FDMT.

Parameters
----------
Expand All @@ -25,7 +26,8 @@ class FDMT:
top and bottom of the band (defines the maximum DM of the search)
num_threads: int
Number of threads to use for parallelization. If 1, no parallelization
'''
"""

fmin: float = attrib(default=400.1953125)
fmax: float = attrib(default=800.1953125)
nchan: int = attrib(default=1024)
Expand All @@ -37,10 +39,10 @@ class FDMT:
fs: np.ndarray = attrib(init=False, eq=cmp_using(eq=np.array_equal))
Q: list = attrib(init=False, eq=cmp_using(eq=np.array_equal))


def __attrs_post_init__(self):
    """Derive the channel frequency grid from the band edges.

    attrs calls this hook at the end of the generated ``__init__``.  It
    fills ``self.fs`` with ``nchan`` evenly spaced frequencies that start at
    ``fmin`` and stop one step short of ``fmax`` (``endpoint=False``), and
    ``self.df`` with the spacing between consecutive entries.
    """
    # retstep=True makes linspace return (samples, step) in a single call,
    # so both attributes come from the same computation.
    self.fs, self.df = np.linspace(
        self.fmin, self.fmax, self.nchan, endpoint=False, retstep=True
    )

def subDT(self, f, dF=None):
"Get needed DT of subband to yield maxDT over entire band"
Expand All @@ -50,32 +52,30 @@ def subDT(self, f, dF=None):
glo = self.fmin**-2 - self.fmax**-2
return np.ceil((self.maxDT - 1) * loc / glo).astype(int) + 1


def buildAB(self, numCols, dtype=np.uint32):
    """Allocate the zero-filled work buffers ``self.A`` and ``self.B``.

    Parameters
    ----------
    numCols : int
        Number of columns given to both buffers.
    dtype : data-type, optional
        Element type of both buffers (default ``np.uint32``).

    Notes
    -----
    ``A`` receives as many rows as ``subDT`` reports in total for every
    channel in ``self.fs``.  ``B`` is sized the same way, but from ``subDT``
    evaluated on every second channel with the width ``fs[2] - fs[0]``.
    """
    rows_per_channel = self.subDT(self.fs)
    rows_per_pair = self.subDT(self.fs[::2], self.fs[2] - self.fs[0])
    self.A = np.zeros((rows_per_channel.sum(), numCols), dtype=dtype)
    self.B = np.zeros((rows_per_pair.sum(), numCols), dtype=dtype)


def buildQ(self):
    """Build ``self.Q``, the table of row offsets for every merge level.

    Level ``i`` groups the band into sub-bands of ``2**i`` channels.  For
    each level, ``subDT`` gives the number of rows every sub-band needs, and
    the exclusive running sum of those counts is the first row of each
    sub-band inside the work buffer.

    The result is stored as a 2-D ``int32`` array of shape
    ``(int(log2(nchan)) + 1, number of level-0 sub-bands)``.  Higher levels
    have fewer sub-bands; their unused trailing entries stay zero.  A
    rectangular array (instead of a list of differently sized arrays) lets
    the table be passed directly to the numba-compiled iteration kernel.
    """
    levels = int(np.log2(self.nchan)) + 1

    offsets = []
    for level in range(levels):
        step = 2**level
        needed = self.subDT(self.fs[::step], self.df * step)
        # cumsum - needed == exclusive prefix sum: start row of each sub-band.
        offsets.append(np.cumsum(needed) - needed)

    # Level 0 has the most sub-bands, so it fixes the table width.
    table = np.zeros((levels, len(offsets[0])), dtype="int32")
    for level, row in enumerate(offsets):
        table[level, : len(row)] = row
    self.Q = table

def prep(self, cols, dtype=np.uint32):
    """Prepare the matrices needed for an FDMT run.

    Allocates the work buffers through ``buildAB`` with ``cols`` columns of
    type ``dtype``, then rebuilds the offset table through ``buildQ``.
    """
    self.buildAB(cols, dtype=dtype)
    self.buildQ()

def fdmt(self, I, retDMT=False, verbose=False, padding=False, frontpadding=True):
"""
Computes DM Transform. If retDMT returns transform, else returns max sigma.

def fdmt(self, I, retDMT=False, verbose=False, padding=False,
frontpadding=True):
"""Computes DM Transform. If retDMT returns transform, else returns max
sigma.

Parameters
==========
I: np.ndarray
Expand Down Expand Up @@ -129,7 +129,12 @@ def fdmt(self, I, retDMT=False, verbose=False, padding=False,

I = np.concatenate(concat_tuple, axis=1)

if self.A is None or self.A.shape[1] != I.shape[1] or self.A.dtype != I.dtype or True:
if (
self.A is None
or self.A.shape[1] != I.shape[1]
or self.A.dtype != I.dtype
or True
):
self.prep(I.shape[1], dtype=I.dtype)

t1 = time()
Expand All @@ -146,21 +151,20 @@ def fdmt(self, I, retDMT=False, verbose=False, padding=False,
print("Iterating time: %.2f s" % (t3 - t2))
print("Total time: %.2f s" % (t3 - t1))

DMT = dest[:self.maxDT]
DMT = dest[: self.maxDT]

if retDMT:
# We need to cut off the first maxDT samples either way
# because now frontpadding works by inserting maxDT samples' worth
# of zeros at the front of I
return DMT[:, self.maxDT:]
return DMT[:, self.maxDT :]
noiseRMS = np.array([DMT[i, i:].std() for i in range(self.maxDT)])
noiseMean = np.array([DMT[i, i:].mean() for i in range(self.maxDT)])
sigmi = (DMT.T - noiseMean) / noiseRMS
if verbose:
print("Maximum sigma value: %.3f" % sigmi.max())
return sigmi.max()


def fdmt_initialize(self, I):
self.A[self.Q[0], :] = I
chDTs = self.subDT(self.fs)
Expand All @@ -169,7 +173,9 @@ def fdmt_initialize(self, I):
DTsteps = list(np.where(chDTs[:-1] - chDTs[1:] != 0)[0])
DTplan = commonDTs + DTsteps[::-1]
for i, t in enumerate(DTplan, 1):
self.A[self.Q[0][:t] + i, i:] = self.A[self.Q[0][:t] + i - 1, i:] + I[:t, :-i]
self.A[self.Q[0][:t] + i, i:] = (
self.A[self.Q[0][:t] + i - 1, i:] + I[:t, :-i]
)
for i, t in enumerate(DTplan, 1):
# A[Q[0][:t]+i,i:] /= int(i+1)
self.A[self.Q[0][:t] + i, i:] /= int(i + 1)
Expand All @@ -185,17 +191,19 @@ def fdmt_iteration(self, src, dest, i):
maxDT = self.maxDT
num_threads = self.num_threads

fdmt_iter_par(fs, nchan, df, Q, src, dest, i, fmin, fmax, np.float32(maxDT), num_threads)

fdmt_iter_par(
fs, nchan, df, Q, src, dest, i, fmin, fmax, np.float32(maxDT), num_threads
)

def reset_ABQ(self):
    """Discard the work buffers and the offset table.

    ``A`` and ``B`` are set to ``None`` and ``Q`` to a new empty list.
    """
    self.A = self.B = None
    self.Q = []

def recursive_fdmt(self, I, depth=0, curMax=0):
"""Performs FDMT, downsamples and repeats recursively, returning max sigma
I should have shape (nchan, nsamp) where nsamp is the number of time samples"""
"""Performs FDMT, downsamples and repeats recursively, returning max sigma I
should have shape (nchan, nsamp) where nsamp is the number of time samples.
"""
curMax = max(curMax, fdmt(I))
if depth <= 0:
return curMax
Expand Down
15 changes: 8 additions & 7 deletions fdmt/fdmt_njit.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from numba import njit, set_num_threads, prange
from numba import njit, prange, set_num_threads


@njit(parallel=True)
def fdmt_iter_par(fs, nchan, df, Q, src, dest, i, fmin, fmax, maxDT, num_threads):
"""
Perform a single iteration of the Fast Dispersion Measure Transform (FDMT) algorithm,
parallelized using numba.
Perform a single iteration of the Fast Dispersion Measure Transform (FDMT)
algorithm, parallelized using numba.

Parameters:
fs (ndarray): Array of center frequencies for each channel.
Expand Down Expand Up @@ -41,16 +42,16 @@ def fdmt_iter_par(fs, nchan, df, Q, src, dest, i, fmin, fmax, maxDT, num_threads
C01 = ((f1 - cor) ** -2 - f0**-2) / (f2**-2 - f0**-2)
C12 = ((f1 + cor) ** -2 - f0**-2) / (f2**-2 - f0**-2)

#SDT
# SDT
loc = f0**-2 - (f0 + dF) ** -2
glo = fmin**-2 - fmax**-2
R = int((maxDT- 1) * loc / glo) + 2
R = int((maxDT - 1) * loc / glo) + 2

for i_dT in prange(0, R):

dT_mid01 = round(i_dT * C01)
dT_mid12 = round(i_dT * C12)
dT_rest = i_dT - dT_mid12
dest[Q[i][i_F] + i_dT, :] = src[Q[i - 1][2 * i_F] + dT_mid01, :]
dest[Q[i][i_F] + i_dT, dT_mid12:] += src[
Q[i - 1][2 * i_F + 1] + dT_rest, : T - dT_mid12]
Q[i - 1][2 * i_F + 1] + dT_rest, : T - dT_mid12
]
Loading