-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added MCMC2 which initialises on every step
- Loading branch information
1 parent
29bb663
commit 02b7ede
Showing
6 changed files
with
282 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
""" | ||
File: MCMC.py | ||
Author: Jordan Hoffmann | ||
Date: 06/12/23 | ||
Purpose: | ||
Contains functions used for MCMC runs of the zdm code. MCMC_wrap.py is the | ||
main function which does the calling and this holds functions which do the | ||
MCMC analysis. | ||
""" | ||
|
||
import numpy as np | ||
|
||
import zdm.iteration as it | ||
from pkg_resources import resource_filename | ||
|
||
import emcee | ||
import scipy.stats as st | ||
import time | ||
|
||
from zdm import loading | ||
from zdm import parameters | ||
|
||
from astropy.cosmology import Planck18 | ||
|
||
import multiprocessing as mp | ||
|
||
from zdm.misc_functions import * | ||
|
||
#============================================================================== | ||
|
||
def calc_log_posterior(param_vals, params, surveys_sep, grid_params): | ||
""" | ||
Calculates the log posterior for a given set of parameters. Assumes uniform | ||
priors between the minimum and maximum values provided in 'params'. | ||
Inputs: | ||
param_vals = Array of the parameter values for this step | ||
params = Dictionary of the parameter names, min and max values | ||
files = Object containing survey_names, rep_survey_names, sdir, edir | ||
Outputs: | ||
llsum = Total log likelihood for param_vals which is equivalent | ||
to log posterior (un-normalised) due to uniform priors | ||
""" | ||
|
||
# t0 = time.time() | ||
# Can use likelihoods instead of posteriors because we only use uniform priors which just changes normalisation of posterior | ||
# given every value is in the correct range. If any value is not in the correct range, log posterior is -inf | ||
in_priors = True | ||
param_dict = {} | ||
|
||
for i, (key,val) in enumerate(params.items()): | ||
if param_vals[i] < val['min'] or param_vals[i] > val['max']: | ||
in_priors = False | ||
break | ||
|
||
param_dict[key] = param_vals[i] | ||
|
||
if in_priors == False: | ||
llsum = -np.inf | ||
else: | ||
|
||
# minimise_const_only does the grid updating so we don't need to do it explicitly beforehand | ||
try: | ||
# Set state | ||
state = parameters.State() | ||
state.update_params(param_dict) | ||
state.set_astropy_cosmo(Planck18) | ||
|
||
# Initialise surveys and grids | ||
grids = [] | ||
if len(surveys_sep[0]) != 0: | ||
zDMgrid, zvals,dmvals = get_zdm_grid( | ||
state, new=True, plot=False, method='analytic', | ||
nz=grid_params['nz'], ndm=grid_params['ndm'], dmmax=grid_params['dmmax'], | ||
datdir=resource_filename('zdm', 'GridData')) | ||
|
||
# generates zdm grid | ||
grids += initialise_grids(surveys_sep[0], zDMgrid, zvals, dmvals, state, wdist=True, repeaters=False) | ||
|
||
if len(surveys_sep[0]) != 0: | ||
zDMgrid, zvals,dmvals = get_zdm_grid( | ||
state, new=True, plot=False, method='analytic', | ||
nz=grid_params['nz'], ndm=grid_params['ndm'], dmmax=grid_params['dmmax'], | ||
datdir=resource_filename('zdm', 'GridData')) | ||
|
||
# generates zdm grid | ||
grids += initialise_grids(surveys_sep[1], zDMgrid, zvals, dmvals, state, wdist=True, repeaters=True) | ||
surveys = surveys_sep[0] + surveys_sep[1] | ||
|
||
# Minimse the constant accross all surveys | ||
newC, llC = it.minimise_const_only(None, grids, surveys) | ||
for g in grids: | ||
g.state.FRBdemo.lC = newC | ||
|
||
# calculate all the likelihoods | ||
llsum = 0 | ||
for s, grid in zip(surveys, grids): | ||
llsum += it.get_log_likelihood(grid,s) | ||
|
||
except ValueError as e: | ||
print("ValueError, setting likelihood to -inf: " + str(e)) | ||
llsum = -np.inf | ||
|
||
if np.isnan(llsum): | ||
print("llsum was NaN. Setting to -infinity", param_dict) | ||
llsum = -np.inf | ||
|
||
# print("Posterior calc time: " + str(time.time()-t0) + " seconds", flush=True) | ||
|
||
return llsum | ||
|
||
#============================================================================== | ||
|
||
def mcmc_runner(logpf, outfile, params, surveys, grid_params, nwalkers=10, nsteps=100, nthreads=1): | ||
""" | ||
Handles the MCMC running. | ||
Inputs: | ||
logpf = Log posterior function handle | ||
outfile = Name of the output file (excluding .h5 extension) | ||
params = Dictionary of the parameter names, min and max values | ||
surveys = List of surveys being used | ||
grids = List of grids corresponding to the surveys | ||
nwalkers = Number of walkers | ||
nsteps = Number of steps per walker | ||
nthreads = Number of threads to use for parallelised runs | ||
Outputs: | ||
posterior_sample = Final sample | ||
outfile.h5 = HDF5 file containing the sampler | ||
""" | ||
|
||
ndim = len(params) | ||
starting_guesses = [] | ||
|
||
# Produce starting guesses for each parameter | ||
for key,val in params.items(): | ||
starting_guesses.append(st.uniform(loc=val['min'], scale=val['max']-val['min']).rvs(size=[nwalkers])) | ||
print(key + " priors: " + str(val['min']) + "," + str(val['max'])) | ||
|
||
starting_guesses = np.array(starting_guesses).T | ||
|
||
backend = emcee.backends.HDFBackend(outfile+'.h5') | ||
backend.reset(nwalkers, ndim) | ||
|
||
start = time.time() | ||
with mp.Pool() as pool: | ||
sampler = emcee.EnsembleSampler(nwalkers, ndim, logpf, args=[params, surveys, grid_params], backend=backend, pool=pool) | ||
sampler.run_mcmc(starting_guesses, nsteps, progress=True) | ||
end = time.time() | ||
print("Total time taken: " + str(end - start)) | ||
|
||
posterior_sample = sampler.get_chain() | ||
|
||
return posterior_sample | ||
|
||
#============================================================================== |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
""" | ||
File: MCMC_wrap.py | ||
Author: Jordan Hoffmann | ||
Date: 28/09/23 | ||
Purpose: | ||
Wrapper file to run MCMC analysis for zdm code. Handles command line | ||
parameters and loading of surveys, grids and parameters. Actual MCMC | ||
analysis functions are in MCMC.py. | ||
scripts/run_mcmc.slurm contains an example sbatch script. | ||
""" | ||
|
||
import argparse | ||
import os | ||
|
||
from zdm import survey | ||
from zdm import cosmology as cos | ||
from zdm import loading | ||
from zdm.MCMC2 import * | ||
|
||
import pickle | ||
import json | ||
|
||
#============================================================================== | ||
|
||
def main(): | ||
""" | ||
Handles the setup for MCMC runs. This involves reading / creating the | ||
surveys and grids, reading the parameters and prior ranges and then | ||
beginning the MCMC run. | ||
Inputs: | ||
args = Command line parameters | ||
Outputs: | ||
None | ||
""" | ||
|
||
# Parsing command line | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-f', '--files', default=None, nargs='+', type=str, help="Survey file names") | ||
parser.add_argument('-r', '--rep_surveys', default=None, nargs='+', type=str, help="Surveys to consider repeaters in") | ||
parser.add_argument('-p','--pfile', default=None , type=str, help="File defining parameter ranges") | ||
parser.add_argument('-o','--opfile', default=None, type=str, help="Output file for the data") | ||
parser.add_argument('-w', '--walkers', default=20, type=int, help="Number of MCMC walkers") | ||
parser.add_argument('-s', '--steps', default=100, type=int, help="Number of MCMC steps") | ||
parser.add_argument('-n', '--nthreads', default=1, type=int, help="Number of threads") | ||
parser.add_argument('--sdir', default=None, type=str, help="Directory containing surveys") | ||
parser.add_argument('--edir', default=None, type=str, help="Directory containing efficiency files") | ||
parser.add_argument('--outdir', default="", type=str, help="Output directory") | ||
args = parser.parse_args() | ||
|
||
# Check correct flags are specified | ||
if args.pfile is None or args.opfile is None: | ||
print("-p and -o flags are required") | ||
exit() | ||
|
||
# Initialise surveys | ||
surveys = [[], []] | ||
state = parameters.State() | ||
|
||
grid_params = {} | ||
grid_params['dmmax'] = 7000.0 | ||
grid_params['ndm'] = 1400 | ||
grid_params['nz'] = 500 | ||
ddm = grid_params['dmmax'] / grid_params['ndm'] | ||
dmvals = (np.arange(grid_params['ndm']) + 1) * ddm | ||
|
||
if args.files is not None: | ||
for survey_name in args.files: | ||
s = survey.load_survey(survey_name, state, dmvals, | ||
sdir=args.sdir, edir=args.edir) | ||
surveys[0].append(s) | ||
|
||
if args.rep_surveys is not None: | ||
for survey_name in args.rep_surveys: | ||
s = survey.load_survey(survey_name, state, dmvals, | ||
sdir=args.sdir, edir=args.edir) | ||
surveys[1].append(s) | ||
|
||
# Make output directory | ||
if args.outdir != "" and not os.path.exists(args.outdir): | ||
os.mkdir(args.outdir) | ||
|
||
with open(args.pfile) as f: | ||
mcmc_dict = json.load(f) | ||
|
||
# Select from dictionary the necessary parameters to be changed | ||
params = {k: mcmc_dict[k] for k in mcmc_dict['mcmc']['parameter_order']} | ||
|
||
mcmc_runner(calc_log_posterior, os.path.join(args.outdir, args.opfile), params, surveys, grid_params, nwalkers=args.walkers, nsteps=args.steps, nthreads=args.nthreads) | ||
|
||
#============================================================================== | ||
|
||
main() |
Oops, something went wrong.