forked from jrs65/pfb-inverse
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path pfbdriver.py
79 lines (58 loc) · 2.06 KB
/
pfbdriver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import h5py
import numpy as np
import mpiutil
import pfb
def _build_parser():
    """Construct the command-line parser for the inverse-PFB driver.

    Positional arguments are the input file and the output file; the optional
    flags are ``-f`` (stored as ``nfreq``), ``-n`` (stored as ``ntap``) and the
    ``-m`` switch (stored as ``no_nyquist``).
    """
    cli = argparse.ArgumentParser(description="Invert the PFB.")

    # Positional arguments, in command-line order.
    for name, text in (
        ("file_in", "Input file to process."),
        ("file_out", "File to write output timestream into."),
    ):
        cli.add_argument(name, help=text)

    # Integer-valued options: (flag, destination, default, help text).
    for flag, dest, default, text in (
        ("-f", "nfreq", 1024, "Number of frequencies in file."),
        ("-n", "ntap", 4, "Number of taps used for PFB"),
    ):
        cli.add_argument(flag, dest=dest, type=int, default=default, help=text)

    # Boolean switch: present on the command line means True.
    cli.add_argument(
        "-m",
        dest="no_nyquist",
        action="store_true",
        help="Input file is missing Nyquist frequency.",
    )
    return cli


parser = _build_parser()
args = parser.parse_args()
# Echo the parsed settings; this is unconditional, so under MPI every rank
# prints it.
print(args)
# =========================
#
# Load the data.
#
# Each process must end up with its own section of the file, from start_block
# to end_block, held in double precision. The .npy file is memory-mapped so
# that each rank only reads (and copies into memory) its own section instead
# of every rank loading the whole file.
#
# =========================
# View the file as rows of nfreq frequency channels. The reshape is lazy on the
# memory map; numpy raises ValueError if the file length is not a multiple of
# nfreq.
pfb_data = np.load(args.file_in, mmap_mode="r").reshape(-1, args.nfreq)
nblock = pfb_data.shape[0]  # Total file length in blocks (rows).
# Work out how to split the file into per-rank sections; split_local is
# assumed to return (n_local, start, end) -- TODO confirm in mpiutil.
local_block, start_block, end_block = mpiutil.split_local(nblock)
# Copy this rank's section out of the memory map, promoted to double precision
# as required above (float32 -> float64, complex64 -> complex128); data that is
# already double precision keeps its dtype.
pfb_local = np.array(
    pfb_data[start_block:end_block],
    dtype=np.result_type(pfb_data.dtype, np.float64),
)
# =========================
# Apply the inverse PFB to this rank's section. The total block count is passed
# as well -- presumably so the parallel routine can handle the overlap between
# neighbouring ranks' sections (TODO confirm against pfb.inverse_pfb_parallel).
rects = pfb.inverse_pfb_parallel(
    pfb_local, args.ntap, nblock, no_nyquist=args.no_nyquist
)
# Calculate the range of local timestream blocks (the same split as the input,
# so this rank's rows of "ts" line up with its input blocks).
local_tsblock, start_tsblock, end_tsblock = mpiutil.split_local(nblock)
# =========================
#
# Write the output. Ranks take turns, in rank order and separated by barriers,
# opening the HDF5 file and writing their own rows of the "ts" dataset, so there
# is only ever one writer at a time. This serializes all I/O.
#
# =========================
ts_shape = (nblock, rects.shape[-1])
for ri in range(mpiutil.size):
    if mpiutil.rank == ri:
        with h5py.File(args.file_out, "a") as f:
            # The output file is opened in append mode, so it may already hold
            # a "ts" dataset from an earlier run. The first writer replaces one
            # whose shape or dtype does not match this run. Otherwise, stale
            # rows beyond nblock would survive, or the data would be silently
            # downcast. Replacing rather than raising avoids leaving the other
            # ranks blocked at the barrier below.
            if ri == 0 and "ts" in f and (
                f["ts"].shape != ts_shape or f["ts"].dtype != rects.dtype
            ):
                del f["ts"]
            if "ts" not in f:
                f.create_dataset("ts", shape=ts_shape, dtype=rects.dtype)
            f["ts"][start_tsblock:end_tsblock] = rects
    # Every rank waits here each turn, so writes never overlap.
    mpiutil.barrier()
# =========================