Skip to content

Commit

Permalink
Merge pull request #70 from CHIMEFRB/chimefrb-updates
Browse files Browse the repository at this point in the history
Chimefrb updates
  • Loading branch information
emmanuelfonseca authored Jul 26, 2023
2 parents 17a412d + e27838d commit 18f2459
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 8 deletions.
2 changes: 1 addition & 1 deletion fitburst/analysis/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def fit(self, exact_jacobian: bool = True) -> None:

except Exception as exc:
print("ERROR: solver encountered a failure! Debug!")
print(sys.exc_info()[2])
print(sys.exc_info())

def fix_parameter(self, parameter_list: list) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion fitburst/analysis/peak_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def get_parameters_dict(self, original_dict: dict, update_width=False):
burst_parameters[current_key] = (self.burst_widths / 1000.).tolist()

else:
burst_parameters[current_key] = original_dict[current_key] * mul_factor
burst_parameters[current_key] = [original_dict[current_key][0]] * mul_factor

#burst_parameters={
# "amplitude" : original_dict*mul_factor,
Expand Down
52 changes: 48 additions & 4 deletions fitburst/pipelines/fitburst_example_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from fitburst.analysis.model import SpectrumModeler
from fitburst.analysis.peak_finder import FindPeak
from fitburst.backend.generic import DataReader
from fitburst.analysis.fitter import LSFitter
from fitburst.analysis.model import SpectrumModeler
from copy import deepcopy
import fitburst.routines.manipulate as manip
import fitburst.utilities as ut
Expand Down Expand Up @@ -140,6 +141,23 @@
"filenames based on input filenames."
)

parser.add_argument(
"--peakfind_dist",
action="store",
dest="peakfind_dist", default=5,
type=int,
help="Separation used for peak-finding algorithm (for multi-component fitting)."
)

parser.add_argument(
"--peakfind_rms",
action="store",
dest="peakfind_rms",
default=None,
type=float,
help="RMS used for peak-finding algorithm (for multi-component fitting)."
)

parser.add_argument(
"--preprocess",
action="store_true",
Expand Down Expand Up @@ -274,6 +292,8 @@
num_iterations = args.num_iterations
parameters_to_fit = args.parameters_to_fit
parameters_to_fix = args.parameters_to_fix
peakfind_rms = args.peakfind_rms
peakfind_dist = args.peakfind_dist
preprocess_data = args.preprocess_data
remove_dispersion_smearing = args.remove_dispersion_smearing
use_outfile_substring = args.use_outfile_substring
Expand Down Expand Up @@ -322,8 +342,15 @@
data.good_freq = np.sum(data.data_weights, axis=1) != 0.
data.good_freq = np.sum(data.data_full, axis=1) != 0.

# just to be sure, loop over data and ensure channels aren't "bad".
for idx_freq in range(data.num_freq):
if data.good_freq[idx_freq]:
if data.data_full[idx_freq, :].min() == data.data_full[idx_freq, :].max():
print(f"ERROR: bad data value of {data.data_full[idx_freq, :].min()} in channel {idx_freq}!")
data.good_freq[idx_freq] = False

if preprocess_data:
data.preprocess_data(normalize_variance=True, variance_range=variance_range)
data.preprocess_data(normalize_variance=False, variance_range=variance_range)

print(f"There are {data.good_freq.sum()} good frequencies...")

Expand All @@ -332,6 +359,7 @@

# check if any initial guesses are missing, and overload 'basic' guess value if so.
initial_parameters = data.burst_parameters
num_components = len(initial_parameters["dm"])
basic_parameters = {
"amplitude" : [-2.0],
"arrival_time" : [np.mean(data.times)],
Expand All @@ -349,9 +377,15 @@

if len(current_list) == 0:
print(f"WARNING: parameter '{current_parameter}' has no value in data file, overloading a basic guess...")
initial_parameters[current_parameter] = basic_parameters[current_parameter]
initial_parameters[current_parameter] = basic_parameters[current_parameter] * num_components

# now see if any parameters are missing in the dictionary.
for current_parameter in basic_parameters.keys():
if current_parameter not in initial_parameters:
initial_parameters[current_parameter] = basic_parameters[current_parameter] * num_components

current_parameters = deepcopy(initial_parameters)
print(f"INFO: current parameters = {current_parameters}")

# update DM value to use ("full" or DM offset) for dedispersion if
# input data are already dedispersed or not.
Expand All @@ -360,14 +394,15 @@
if data.is_dedispersed:
print("INFO: input data cube is already dedispersed!")
print("INFO: setting 'dm' entry to 0, now considered a dm-offset parameter...")
current_parameters["dm"][0] = 0.0
current_parameters["dm"] = [0.0] * len(initial_parameters["dm"])

if not remove_dispersion_smearing:
dm_incoherent = 0.

# if an existing solution is supplied in a JSON file, then read it or use basic guesses.
if existing_results is not None:
current_parameters = existing_results["model_parameters"]
num_components = len(current_parameters["dm"])

# if values are supplied at command line, then overload those here.
if amplitude is not None:
Expand Down Expand Up @@ -421,6 +456,14 @@
window=window
)

# before instantiating model, run peak-finding algorithm if desired.
if peakfind_rms is not None:
print("INFO: running FindPeak to isolate burst components...")
peaks = FindPeak(data_windowed, times_windowed, data.freqs, rms=peakfind_rms)
peaks.find_peak(distance=peakfind_dist)
current_parameters = peaks.get_parameters_dict(current_parameters)
num_components = len(current_parameters["dm"])

# now create initial model.
print("INFO: initializing model")
model = SpectrumModeler(
Expand All @@ -429,6 +472,7 @@
dm_incoherent = dm_incoherent,
factor_freq_upsample = factor_freq_upsample,
factor_time_upsample = factor_time_upsample,
num_components = num_components,
is_dedispersed = data.is_dedispersed,
is_folded = is_folded,
scintillation = scintillation,
Expand Down
12 changes: 10 additions & 2 deletions fitburst/utilities/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,13 @@ def preprocess_data(self, normalize_variance: bool = True, skewness_range: list
mask_freq = np.sum(self.data_weights, -1)
good_freq = mask_freq != 0

# just to be sure, loop over data and ensure channels aren't "bad".
for idx_freq in range(self.num_freq):
if good_freq[idx_freq]:
if self.data_full[idx_freq, :].min() == self.data_full[idx_freq, :].max():
print(f"ERROR: bad data value of {self.data_full[idx_freq, :].min()} in channel {idx_freq}!")
good_freq[idx_freq] = False

# normalize data and remove baseline.
mean_spectrum = np.sum(self.data_full * self.data_weights, -1)
#good_freq[np.where(mean_spectrum == 0.)] = False
Expand All @@ -242,14 +249,15 @@ def preprocess_data(self, normalize_variance: bool = True, skewness_range: list
self.data_full[np.logical_not(self.data_weights)] = 0

# compute variance and skewness of data.
variance = np.sum(self.data_full**2, -1)
variance = np.sum(self.data_full ** 2, -1)
variance[good_freq] /= mask_freq[good_freq]
skewness = np.sum(self.data_full**3, -1)
skewness = np.sum(self.data_full ** 3, -1)
skewness[good_freq] /= mask_freq[good_freq]
skewness_mean = np.mean(skewness[good_freq])
skewness_std = np.std(skewness[good_freq])

# if desired, normalize variance relative to maximum value.

if normalize_variance:
variance[good_freq] /= np.max(variance[good_freq])

Expand Down

0 comments on commit 18f2459

Please sign in to comment.