diff --git a/zdm/MCMC.py b/zdm/MCMC.py
index 0e869f24..0c14c942 100644
--- a/zdm/MCMC.py
+++ b/zdm/MCMC.py
@@ -14,7 +14,6 @@
 
 import emcee
 import scipy.stats as st
-import pickle
 import time
 
 import multiprocessing as mp
@@ -39,7 +38,7 @@ def calc_log_posterior(param_vals, params, surveys, grids):
                         to log posterior (un-normalised) due to uniform priors
     """
 
-    t0 = time.time()
+    # t0 = time.time()
     # Can use likelihoods instead of posteriors because we only use uniform priors which just changes normalisation of posterior
     # given every value is in the correct range. If any value is not in the correct range, log posterior is -inf
     in_priors = True
@@ -65,11 +64,6 @@ def calc_log_posterior(param_vals, params, surveys, grids):
             # calculate all the likelihoods
             llsum = 0
             for s, grid in zip(surveys, grids):
-                # if DMhalo != None:
-                #     s.init_DMEG(DMhalo)
-                if 'DMhalo' in param_dict:
-                    s.init_DMEG(param_dict['DMhalo'])
-
                 llsum += it.get_log_likelihood(grid,s)
 
         except ValueError as e:
@@ -80,7 +74,7 @@ def calc_log_posterior(param_vals, params, surveys, grids):
         print("llsum was NaN. Setting to -infinity", param_dict)
         llsum = -np.inf
 
-    print("Posterior calc time: " + str(time.time()-t0) + " seconds", flush=True)
+    # print("Posterior calc time: " + str(time.time()-t0) + " seconds", flush=True)
 
     return llsum
 
@@ -118,12 +112,12 @@ def mcmc_runner(logpf, outfile, params, surveys, grids, nwalkers=10, nsteps=100,
     backend = emcee.backends.HDFBackend(outfile+'.h5')
     backend.reset(nwalkers, ndim)
 
-    start = time.time()
+    # start = time.time()
     with mp.Pool(nthreads) as pool:
         sampler = emcee.EnsembleSampler(nwalkers, ndim, logpf, args=[params, surveys, grids], backend=backend, pool=pool)
         sampler.run_mcmc(starting_guesses, nsteps, progress=True)
-    end = time.time()
-    print("Total time taken: " + str(end - start))
+    # end = time.time()
+    # print("Total time taken: " + str(end - start))
 
     posterior_sample = sampler.get_chain()
 
diff --git a/zdm/MCMC2.py b/zdm/MCMC2.py
new file mode 100644
index 00000000..deb4f42d
--- /dev/null
+++ b/zdm/MCMC2.py
@@ -0,0 +1,159 @@
+"""
+File: MCMC2.py
+Author: Jordan Hoffmann
+Date: 06/12/23
+Purpose:
+    Contains functions used for MCMC runs of the zdm code. MCMC_wrap2.py is the
+    main function which does the calling and this holds functions which do the
+    MCMC analysis.
+"""
+
+import numpy as np
+
+import zdm.iteration as it
+from pkg_resources import resource_filename
+
+import emcee
+import scipy.stats as st
+import time
+
+from zdm import loading
+from zdm import parameters
+
+from astropy.cosmology import Planck18
+
+import multiprocessing as mp
+
+from zdm.misc_functions import *
+
+#==============================================================================
+
+def calc_log_posterior(param_vals, params, surveys_sep, grid_params):
+    """
+    Calculates the log posterior for a given set of parameters. Assumes uniform
+    priors between the minimum and maximum values provided in 'params'.
+
+    Inputs:
+        param_vals  =   Array of the parameter values for this step
+        params      =   Dictionary of the parameter names, min and max values
+        surveys_sep =   Two lists: [non-repeater surveys, repeater surveys]
+        grid_params =   Dictionary holding the grid settings 'nz', 'ndm' and 'dmmax'
+
+    Outputs:
+        llsum       =   Total log likelihood for param_vals which is equivalent
+                        to log posterior (un-normalised) due to uniform priors
+    """
+
+    # t0 = time.time()
+    # Can use likelihoods instead of posteriors because we only use uniform priors which just changes normalisation of posterior
+    # given every value is in the correct range. If any value is not in the correct range, log posterior is -inf
+    in_priors = True
+    param_dict = {}
+
+    for i, (key,val) in enumerate(params.items()):
+        if param_vals[i] < val['min'] or param_vals[i] > val['max']:
+            in_priors = False
+            break
+
+        param_dict[key] = param_vals[i]
+
+    if not in_priors:
+        llsum = -np.inf
+    else:
+
+        # Grids are built below for this parameter set; minimise_const_only is passed None as its parameter dictionary
+        try:
+            # Set state
+            state = parameters.State()
+            state.update_params(param_dict)
+            state.set_astropy_cosmo(Planck18)
+
+            # Initialise surveys and grids
+            grids = []
+            if len(surveys_sep[0]) != 0:
+                zDMgrid, zvals,dmvals = get_zdm_grid(
+                    state, new=True, plot=False, method='analytic',
+                    nz=grid_params['nz'], ndm=grid_params['ndm'], dmmax=grid_params['dmmax'],
+                    datdir=resource_filename('zdm', 'GridData'))
+
+                # generates zdm grid
+                grids += initialise_grids(surveys_sep[0], zDMgrid, zvals, dmvals, state, wdist=True, repeaters=False)
+
+            if len(surveys_sep[1]) != 0:
+                zDMgrid, zvals,dmvals = get_zdm_grid(
+                    state, new=True, plot=False, method='analytic',
+                    nz=grid_params['nz'], ndm=grid_params['ndm'], dmmax=grid_params['dmmax'],
+                    datdir=resource_filename('zdm', 'GridData'))
+
+                # generates zdm grid
+                grids += initialise_grids(surveys_sep[1], zDMgrid, zvals, dmvals, state, wdist=True, repeaters=True)
+            surveys = surveys_sep[0] + surveys_sep[1]
+
+            # Minimise the constant across all surveys
+            newC, llC = it.minimise_const_only(None, grids, surveys)
+            for g in grids:
+                g.state.FRBdemo.lC = newC
+
+            # calculate all the likelihoods
+            llsum = 0
+            for s, grid in zip(surveys, grids):
+                llsum += it.get_log_likelihood(grid,s)
+
+        except ValueError as e:
+            print("ValueError, setting likelihood to -inf: " + str(e))
+            llsum = -np.inf
+
+    if np.isnan(llsum):
+        print("llsum was NaN. Setting to -infinity", param_dict)
+        llsum = -np.inf
+
+    # print("Posterior calc time: " + str(time.time()-t0) + " seconds", flush=True)
+
+    return llsum
+
+#==============================================================================
+
+def mcmc_runner(logpf, outfile, params, surveys, grid_params, nwalkers=10, nsteps=100, nthreads=1):
+    """
+    Handles the MCMC running.
+
+    Inputs:
+        logpf       =   Log posterior function handle
+        outfile     =   Name of the output file (excluding .h5 extension)
+        params      =   Dictionary of the parameter names, min and max values
+        surveys     =   Surveys, forwarded unchanged to logpf
+        grid_params =   Grid settings, forwarded unchanged to logpf
+        nwalkers    =   Number of walkers
+        nsteps      =   Number of steps per walker
+        nthreads    =   Number of worker processes in the multiprocessing pool
+
+    Outputs:
+        posterior_sample    =   Final sample
+        outfile.h5          =   HDF5 file containing the sampler
+    """
+
+    ndim = len(params)
+    starting_guesses = []
+
+    # Produce starting guesses for each parameter
+    for key,val in params.items():
+        starting_guesses.append(st.uniform(loc=val['min'], scale=val['max']-val['min']).rvs(size=[nwalkers]))
+        print(key + " priors: " + str(val['min']) + "," + str(val['max']))
+
+    starting_guesses = np.array(starting_guesses).T
+
+    backend = emcee.backends.HDFBackend(outfile+'.h5')
+    backend.reset(nwalkers, ndim)
+
+    start = time.time()
+    with mp.Pool(nthreads) as pool:
+        sampler = emcee.EnsembleSampler(nwalkers, ndim, logpf, args=[params, surveys, grid_params], backend=backend, pool=pool)
+        sampler.run_mcmc(starting_guesses, nsteps, progress=True)
+    end = time.time()
+    print("Total time taken: " + str(end - start))
+
+    posterior_sample = sampler.get_chain()
+
+    return posterior_sample
+
+#==============================================================================
diff --git a/zdm/MCMC_wrap.py b/zdm/MCMC_wrap.py
index 49767a03..99d6c392 100644
--- a/zdm/MCMC_wrap.py
+++ b/zdm/MCMC_wrap.py
@@ -13,12 +13,9 @@
 import argparse
 import os
 
-from zdm import survey
-from zdm import cosmology as cos
 from zdm import loading
 from zdm.MCMC import *
 
-import pickle
 import json
 
 #==============================================================================
@@ -38,8 +35,8 @@ def main():
 
     # Parsing command line
     parser = argparse.ArgumentParser()
-    parser.add_argument('-f', '--files', default=None, nargs='?', type=commasep, help="Survey file names")
-    parser.add_argument('-r', '--rep_surveys', default=None, nargs='?', type=commasep, help="Surveys to consider repeaters in")
+    parser.add_argument('-f', '--files', default=None, nargs='+', type=str, help="Survey file names")
+    parser.add_argument('-r', '--rep_surveys', default=None, nargs='+', type=str, help="Surveys to consider repeaters in")
     parser.add_argument('-p','--pfile', default=None , type=str, help="File defining parameter ranges")
     parser.add_argument('-o','--opfile', default=None, type=str, help="Output file for the data")
     parser.add_argument('-w', '--walkers', default=20, type=int, help="Number of MCMC walkers")
@@ -84,21 +81,6 @@ def main():
     mcmc_runner(calc_log_posterior, os.path.join(args.outdir, args.opfile), params, surveys, grids, nwalkers=args.walkers, nsteps=args.steps, nthreads=args.nthreads)
 
 #==============================================================================
-"""
-Function: commasep
-Date: 23/08/2022
-Purpose:
-    Turn a string of variables seperated by commas into a list
-
-Imports:
-    s   =   String of variables
-
-Exports:
-    List conversion of s
-"""
-def commasep(s):
-    return list(map(str, s.split(',')))
-
-#==============================================================================
-
-main()
\ No newline at end of file
+
+if __name__ == "__main__":
+    main()
diff --git a/zdm/MCMC_wrap2.py b/zdm/MCMC_wrap2.py
new file mode 100644
index 00000000..5aac2850
--- /dev/null
+++ b/zdm/MCMC_wrap2.py
@@ -0,0 +1,101 @@
+"""
+File: MCMC_wrap2.py
+Author: Jordan Hoffmann
+Date: 28/09/23
+Purpose:
+    Wrapper file to run MCMC analysis for zdm code. Handles command line
+    parameters and loading of surveys, grid settings and parameters. Actual
+    MCMC analysis functions are in MCMC2.py.
+
+    scripts/run_mcmc.slurm contains an example sbatch script.
+"""
+
+import argparse
+import os
+
+import numpy as np
+
+from zdm import survey
+from zdm import cosmology as cos
+from zdm import loading
+from zdm import parameters
+from zdm.MCMC2 import *
+
+import pickle
+import json
+
+#==============================================================================
+
+def main():
+    """
+    Handles the setup for MCMC runs. This involves loading the surveys,
+    defining the grid dimensions, reading the parameters and prior ranges and
+    then beginning the MCMC run.
+
+    Inputs:
+        None (options are read from the command line)
+
+    Outputs:
+        None
+    """
+
+    # Parsing command line
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-f', '--files', default=None, nargs='+', type=str, help="Survey file names")
+    parser.add_argument('-r', '--rep_surveys', default=None, nargs='+', type=str, help="Surveys to consider repeaters in")
+    parser.add_argument('-p','--pfile', default=None , type=str, help="File defining parameter ranges")
+    parser.add_argument('-o','--opfile', default=None, type=str, help="Output file for the data")
+    parser.add_argument('-w', '--walkers', default=20, type=int, help="Number of MCMC walkers")
+    parser.add_argument('-s', '--steps', default=100, type=int, help="Number of MCMC steps")
+    parser.add_argument('-n', '--nthreads', default=1, type=int, help="Number of threads")
+    parser.add_argument('--sdir', default=None, type=str, help="Directory containing surveys")
+    parser.add_argument('--edir', default=None, type=str, help="Directory containing efficiency files")
+    parser.add_argument('--outdir', default="", type=str, help="Output directory")
+    args = parser.parse_args()
+
+    # Check correct flags are specified (parser.error prints usage and exits with a non-zero status)
+    if args.pfile is None or args.opfile is None:
+        parser.error("-p and -o flags are required")
+
+    # Initialise surveys: surveys[0] holds the -f surveys, surveys[1] the -r (repeater) surveys
+    surveys = [[], []]
+    state = parameters.State()
+
+    grid_params = {}
+    grid_params['dmmax'] = 7000.0
+    grid_params['ndm'] = 1400
+    grid_params['nz'] = 500
+    ddm = grid_params['dmmax'] / grid_params['ndm']
+    dmvals = (np.arange(grid_params['ndm']) + 1) * ddm
+
+    if args.files is not None:
+        for survey_name in args.files:
+            s = survey.load_survey(survey_name, state, dmvals,
+                                   sdir=args.sdir, edir=args.edir)
+            surveys[0].append(s)
+
+    if args.rep_surveys is not None:
+        for survey_name in args.rep_surveys:
+            s = survey.load_survey(survey_name, state, dmvals,
+                                   sdir=args.sdir, edir=args.edir)
+            surveys[1].append(s)
+
+    if not surveys[0] and not surveys[1]:
+        raise ValueError("No surveys to use!")
+
+    # Make output directory
+    if args.outdir != "":
+        os.makedirs(args.outdir, exist_ok=True)
+
+    with open(args.pfile) as f:
+        mcmc_dict = json.load(f)
+
+    # Select from dictionary the necessary parameters to be changed
+    params = {k: mcmc_dict[k] for k in mcmc_dict['mcmc']['parameter_order']}
+
+    mcmc_runner(calc_log_posterior, os.path.join(args.outdir, args.opfile), params, surveys, grid_params, nwalkers=args.walkers, nsteps=args.steps, nthreads=args.nthreads)
+
+#==============================================================================
+
+if __name__ == "__main__":
+    main()
diff --git a/zdm/grid.py b/zdm/grid.py
index b25bd85d..5b911b34 100644
--- a/zdm/grid.py
+++ b/zdm/grid.py
@@ -18,7 +18,7 @@ class Grid:
     It also assumes a linear uniform grid.
     """
 
-    def __init__(self, survey, state, zDMgrid, zvals, dmvals, smear_mask, wdist):
+    def __init__(self, survey, state, zDMgrid, zvals, dmvals, smear_mask, wdist, prev_grid=None):
         """
         Class constructor.
 
@@ -57,8 +57,16 @@ def __init__(self, survey, state, zDMgrid, zvals, dmvals, smear_mask, wdist):
         # Init the grid
         # THESE SHOULD BE THE SAME ORDER AS self.update()
         self.parse_grid(zDMgrid.copy(), zvals.copy(), dmvals.copy())
-        self.calc_dV()
-        self.smear_dm(smear_mask.copy())
+
+        if prev_grid is None:
+            self.calc_dV()
+            self.smear_dm(smear_mask.copy())
+        else:
+            # Copy dV, smear and smear_grid from prev_grid instead of recomputing them
+            self.dV = prev_grid.dV.copy()
+            self.smear = prev_grid.smear.copy()
+            self.smear_grid = prev_grid.smear_grid.copy()
+
         if wdist:
             efficiencies = survey.efficiencies  # two dimensions
             weights = survey.wplist
@@ -749,6 +757,10 @@ def update(self, vparams: dict, ALL=False, prev_grid=None):
             calc_pdv = True
             new_pdv_smear = True
 
+        # DMhalo changed: pass the new value on to the survey via init_DMEG
+        if self.chk_upd_param("DMhalo", vparams, update=True):
+            self.survey.init_DMEG(vparams["DMhalo"])
+
         # ###########################
         # NOW DO THE REAL WORK!!
 
diff --git a/zdm/misc_functions.py b/zdm/misc_functions.py
index 7b4af9ad..027adabc 100644
--- a/zdm/misc_functions.py
+++ b/zdm/misc_functions.py
@@ -1942,18 +1942,21 @@ def initialise_grids(
     )
     grids = []
+    # Initialised once, outside the loop, so each grid after the first receives its predecessor
+    prev_grid = None
     for survey in surveys:
-        print(f"Working on {survey.name}")
+        # print(f"Working on {survey.name}")
 
         if repeaters:
             grid = zdm_repeat_grid.repeat_Grid(
-                survey, copy.deepcopy(state), zDMgrid, zvals, dmvals, mask, wdist
+                survey, copy.deepcopy(state), zDMgrid, zvals, dmvals, mask, wdist, prev_grid=prev_grid
             )
         else:
             grid = zdm_grid.Grid(
-                survey, copy.deepcopy(state), zDMgrid, zvals, dmvals, mask, wdist
+                survey, copy.deepcopy(state), zDMgrid, zvals, dmvals, mask, wdist, prev_grid=prev_grid
             )
 
         grids.append(grid)
+        prev_grid = grid
 
     return grids
 