Skip to content

Commit

Permalink
Njit parallelization (#16)
Browse files Browse the repository at this point in the history
* first working njit, improvement needed

* first working version, 2x improvement for 8192 chan

* generalize num_threads

* docstring for fdmt_iter_par

* slight change to docstring

* poetry install setup

* feat(pyproject.toml): Adding Poetry to project

* Fill Q into empty array for njit

* build Q as array

* Build Q as array of ints, gives same result and avoids numba warnings

* pre-commit formatting

---------

Co-authored-by: chrisfandrade16 <[email protected]>
  • Loading branch information
ramain and chrisfandrade16 authored Nov 9, 2023
1 parent 4331abb commit ee94df7
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 38 deletions.
70 changes: 39 additions & 31 deletions fdmt/cpu_fdmt.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
#!/usr/bin/env python3

import sys
import numpy as np
from time import time
from attr import attrs, attrib, cmp_using
#from . import fdmt_iter_par
from .fdmt_njit import fdmt_iter_par

import numpy as np
from attr import attrib, attrs, cmp_using

from fdmt.fdmt_njit import fdmt_iter_par


@attrs
class FDMT:
'''
Collection of attributes and helper arrays necessary for dispersion using
FDMT.
"""
Collection of attributes and helper arrays necessary for dispersion using FDMT.
Parameters
---------
Expand All @@ -25,7 +26,8 @@ class FDMT:
top and bottom of the band (defines the maximum DM of the search)
num_threads: int
Number of threads to use for parallelization. If 1, no parallelization
'''
"""

fmin: float = attrib(default=400.1953125)
fmax: float = attrib(default=800.1953125)
nchan: int = attrib(default=1024)
Expand All @@ -37,10 +39,10 @@ class FDMT:
fs: np.ndarray = attrib(init=False, eq=cmp_using(eq=np.array_equal))
Q: list = attrib(init=False, eq=cmp_using(eq=np.array_equal))


def __attrs_post_init__(self):
    """Derive the channel frequency grid from ``fmin``, ``fmax`` and ``nchan``.

    Sets ``self.fs`` to ``nchan`` evenly spaced frequencies that start at
    ``fmin`` and stop short of ``fmax`` (``endpoint=False``), and
    ``self.df`` to the spacing between consecutive entries of ``self.fs``.
    """
    # retstep=True makes linspace return (samples, step) in a single call,
    # so the grid and its spacing always agree with each other.
    self.fs, self.df = np.linspace(
        self.fmin, self.fmax, self.nchan, endpoint=False, retstep=True
    )

def subDT(self, f, dF=None):
"Get needed DT of subband to yield maxDT over entire band"
Expand All @@ -50,32 +52,30 @@ def subDT(self, f, dF=None):
glo = self.fmin**-2 - self.fmax**-2
return np.ceil((self.maxDT - 1) * loc / glo).astype(int) + 1


def buildAB(self, numCols, dtype=np.uint32):
    """Allocate the zero-filled work buffers ``self.A`` and ``self.B``.

    Parameters
    ----------
    numCols: int
        Number of columns given to both buffers.
    dtype: numpy dtype
        Element type of both buffers (default ``np.uint32``).

    Notes
    -----
    ``A`` gets one row per entry counted by ``subDT`` over every channel in
    ``self.fs``.  ``B`` is sized the same way for every second channel,
    using the width ``self.fs[2] - self.fs[0]``; this indexes ``self.fs[2]``
    and therefore needs at least three channels.
    """
    numRowsA = self.subDT(self.fs).sum()
    numRowsB = self.subDT(self.fs[::2], self.fs[2] - self.fs[0]).sum()
    self.A = np.zeros([numRowsA, numCols], dtype)
    self.B = np.zeros([numRowsB, numCols], dtype)


def buildQ(self):
    """Build ``self.Q``, the row-offset table used at each FDMT iteration.

    ``self.Q`` is an int32 array with one row per iteration level
    (``int(log2(nchan)) + 1`` rows) and one column per channel.  For level
    ``i`` the channels are taken with stride ``2**i`` and width
    ``df * 2**i``; entry ``Q[i, j]`` is the exclusive cumulative sum of the
    ``subDT`` counts, i.e. the index of the first row belonging to
    sub-band ``j``.  Levels with fewer sub-bands than channels leave the
    trailing entries at zero.

    A rectangular integer array is used instead of a list of arrays so the
    numba-compiled ``fdmt_iter_par`` can index it without warnings.
    """
    num_levels = int(np.log2(self.nchan)) + 1
    # Level 0 uses every channel, so it defines the widest row.
    offsets = np.zeros((num_levels, len(self.fs)), dtype="int32")
    for level in range(num_levels):
        stride = 2**level
        needed = self.subDT(self.fs[::stride], self.df * stride)
        # cumsum - needed turns per-sub-band counts into start offsets.
        offsets[level, : len(needed)] = np.cumsum(needed) - needed
    self.Q = offsets

def prep(self, cols, dtype=np.uint32):
    """Prepare the matrices needed for FDMT.

    Calls ``buildAB(cols, dtype=dtype)`` and then ``buildQ()``.

    Parameters
    ----------
    cols: int
        Number of columns forwarded to ``buildAB``.
    dtype: numpy dtype
        Element type forwarded to ``buildAB`` (default ``np.uint32``).
    """
    self.buildAB(cols, dtype=dtype)
    self.buildQ()

def fdmt(self, I, retDMT=False, verbose=False, padding=False, frontpadding=True):
"""
Computes DM Transform. If retDMT returns transform, else returns max sigma.
def fdmt(self, I, retDMT=False, verbose=False, padding=False,
frontpadding=True):
"""Computes DM Transform. If retDMT returns transform, else returns max
sigma.
Parameters
==========
I: np.ndarray
Expand Down Expand Up @@ -129,7 +129,12 @@ def fdmt(self, I, retDMT=False, verbose=False, padding=False,

I = np.concatenate(concat_tuple, axis=1)

if self.A is None or self.A.shape[1] != I.shape[1] or self.A.dtype != I.dtype or True:
if (
self.A is None
or self.A.shape[1] != I.shape[1]
or self.A.dtype != I.dtype
or True
):
self.prep(I.shape[1], dtype=I.dtype)

t1 = time()
Expand All @@ -146,21 +151,20 @@ def fdmt(self, I, retDMT=False, verbose=False, padding=False,
print("Iterating time: %.2f s" % (t3 - t2))
print("Total time: %.2f s" % (t3 - t1))

DMT = dest[:self.maxDT]
DMT = dest[: self.maxDT]

if retDMT:
# We need to cut off the first maxDT samples either way
# because now frontpadding works by inserting maxDT samples' worth
# of zeros at the front of I
return DMT[:, self.maxDT:]
return DMT[:, self.maxDT :]
noiseRMS = np.array([DMT[i, i:].std() for i in range(self.maxDT)])
noiseMean = np.array([DMT[i, i:].mean() for i in range(self.maxDT)])
sigmi = (DMT.T - noiseMean) / noiseRMS
if verbose:
print("Maximum sigma value: %.3f" % sigmi.max())
return sigmi.max()


def fdmt_initialize(self, I):
self.A[self.Q[0], :] = I
chDTs = self.subDT(self.fs)
Expand All @@ -169,7 +173,9 @@ def fdmt_initialize(self, I):
DTsteps = list(np.where(chDTs[:-1] - chDTs[1:] != 0)[0])
DTplan = commonDTs + DTsteps[::-1]
for i, t in enumerate(DTplan, 1):
self.A[self.Q[0][:t] + i, i:] = self.A[self.Q[0][:t] + i - 1, i:] + I[:t, :-i]
self.A[self.Q[0][:t] + i, i:] = (
self.A[self.Q[0][:t] + i - 1, i:] + I[:t, :-i]
)
for i, t in enumerate(DTplan, 1):
# A[Q[0][:t]+i,i:] /= int(i+1)
self.A[self.Q[0][:t] + i, i:] /= int(i + 1)
Expand All @@ -185,17 +191,19 @@ def fdmt_iteration(self, src, dest, i):
maxDT = self.maxDT
num_threads = self.num_threads

fdmt_iter_par(fs, nchan, df, Q, src, dest, i, fmin, fmax, np.float32(maxDT), num_threads)

fdmt_iter_par(
fs, nchan, df, Q, src, dest, i, fmin, fmax, np.float32(maxDT), num_threads
)

def reset_ABQ(self):
    """Drop the prepared helper arrays.

    Sets ``self.A`` and ``self.B`` to ``None`` and ``self.Q`` to an empty
    list.
    """
    self.A = None
    self.B = None
    self.Q = []

def recursive_fdmt(self, I, depth=0, curMax=0):
"""Performs FDMT, downsamples and repeats recursively, returning max sigma
I should have shape (nchan, nsamp) where nsamp is the number of time samples"""
"""Performs FDMT, downsamples and repeats recursively, returning max sigma I
should have shape (nchan, nsamp) where nsamp is the number of time samples.
"""
curMax = max(curMax, fdmt(I))
if depth <= 0:
return curMax
Expand Down
15 changes: 8 additions & 7 deletions fdmt/fdmt_njit.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from numba import njit, set_num_threads, prange
from numba import njit, prange, set_num_threads


@njit(parallel=True)
def fdmt_iter_par(fs, nchan, df, Q, src, dest, i, fmin, fmax, maxDT, num_threads):
"""
Perform a single iteration of the Fast Dispersion Measure Transform (FDMT) algorithm,
parallelized using numba.
Perform a single iteration of the Fast Dispersion Measure Transform (FDMT)
algorithm, parallelized using numba.
Parameters:
fs (ndarray): Array of center frequencies for each channel.
Expand Down Expand Up @@ -41,16 +42,16 @@ def fdmt_iter_par(fs, nchan, df, Q, src, dest, i, fmin, fmax, maxDT, num_threads
C01 = ((f1 - cor) ** -2 - f0**-2) / (f2**-2 - f0**-2)
C12 = ((f1 + cor) ** -2 - f0**-2) / (f2**-2 - f0**-2)

#SDT
# SDT
loc = f0**-2 - (f0 + dF) ** -2
glo = fmin**-2 - fmax**-2
R = int((maxDT- 1) * loc / glo) + 2
R = int((maxDT - 1) * loc / glo) + 2

for i_dT in prange(0, R):

dT_mid01 = round(i_dT * C01)
dT_mid12 = round(i_dT * C12)
dT_rest = i_dT - dT_mid12
dest[Q[i][i_F] + i_dT, :] = src[Q[i - 1][2 * i_F] + dT_mid01, :]
dest[Q[i][i_F] + i_dT, dT_mid12:] += src[
Q[i - 1][2 * i_F + 1] + dT_rest, : T - dT_mid12]
Q[i - 1][2 * i_F + 1] + dT_rest, : T - dT_mid12
]

0 comments on commit ee94df7

Please sign in to comment.