Skip to content

Commit

Permalink
Added MCMC2 which initialises on every step
Browse files Browse the repository at this point in the history
  • Loading branch information
JordanHoffmann3 committed Jan 22, 2024
1 parent 29bb663 commit 02b7ede
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 47 deletions.
18 changes: 6 additions & 12 deletions zdm/MCMC.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import emcee
import scipy.stats as st
import pickle
import time

import multiprocessing as mp
Expand All @@ -39,7 +38,7 @@ def calc_log_posterior(param_vals, params, surveys, grids):
to log posterior (un-normalised) due to uniform priors
"""

t0 = time.time()
# t0 = time.time()
# Can use likelihoods instead of posteriors because we only use uniform priors which just changes normalisation of posterior
# given every value is in the correct range. If any value is not in the correct range, log posterior is -inf
in_priors = True
Expand All @@ -65,11 +64,6 @@ def calc_log_posterior(param_vals, params, surveys, grids):
# calculate all the likelihoods
llsum = 0
for s, grid in zip(surveys, grids):
# if DMhalo != None:
# s.init_DMEG(DMhalo)
if 'DMhalo' in param_dict:
s.init_DMEG(param_dict['DMhalo'])

llsum += it.get_log_likelihood(grid,s)

except ValueError as e:
Expand All @@ -80,7 +74,7 @@ def calc_log_posterior(param_vals, params, surveys, grids):
print("llsum was NaN. Setting to -infinity", param_dict)
llsum = -np.inf

print("Posterior calc time: " + str(time.time()-t0) + " seconds", flush=True)
# print("Posterior calc time: " + str(time.time()-t0) + " seconds", flush=True)

return llsum

Expand Down Expand Up @@ -118,12 +112,12 @@ def mcmc_runner(logpf, outfile, params, surveys, grids, nwalkers=10, nsteps=100,
backend = emcee.backends.HDFBackend(outfile+'.h5')
backend.reset(nwalkers, ndim)

start = time.time()
with mp.Pool(nthreads) as pool:
# start = time.time()
with mp.Pool() as pool:
sampler = emcee.EnsembleSampler(nwalkers, ndim, logpf, args=[params, surveys, grids], backend=backend, pool=pool)
sampler.run_mcmc(starting_guesses, nsteps, progress=True)
end = time.time()
print("Total time taken: " + str(end - start))
# end = time.time()
# print("Total time taken: " + str(end - start))

posterior_sample = sampler.get_chain()

Expand Down
158 changes: 158 additions & 0 deletions zdm/MCMC2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""
File: MCMC.py
Author: Jordan Hoffmann
Date: 06/12/23
Purpose:
Contains functions used for MCMC runs of the zdm code. MCMC_wrap.py is the
main function which does the calling and this holds functions which do the
MCMC analysis.
"""

import numpy as np

import zdm.iteration as it
from pkg_resources import resource_filename

import emcee
import scipy.stats as st
import time

from zdm import loading
from zdm import parameters

from astropy.cosmology import Planck18

import multiprocessing as mp

from zdm.misc_functions import *

#==============================================================================

def calc_log_posterior(param_vals, params, surveys_sep, grid_params):
"""
Calculates the log posterior for a given set of parameters. Assumes uniform
priors between the minimum and maximum values provided in 'params'.
Inputs:
param_vals = Array of the parameter values for this step
params = Dictionary of the parameter names, min and max values
files = Object containing survey_names, rep_survey_names, sdir, edir
Outputs:
llsum = Total log likelihood for param_vals which is equivalent
to log posterior (un-normalised) due to uniform priors
"""

# t0 = time.time()
# Can use likelihoods instead of posteriors because we only use uniform priors which just changes normalisation of posterior
# given every value is in the correct range. If any value is not in the correct range, log posterior is -inf
in_priors = True
param_dict = {}

for i, (key,val) in enumerate(params.items()):
if param_vals[i] < val['min'] or param_vals[i] > val['max']:
in_priors = False
break

param_dict[key] = param_vals[i]

if in_priors == False:
llsum = -np.inf
else:

# minimise_const_only does the grid updating so we don't need to do it explicitly beforehand
try:
# Set state
state = parameters.State()
state.update_params(param_dict)
state.set_astropy_cosmo(Planck18)

# Initialise surveys and grids
grids = []
if len(surveys_sep[0]) != 0:
zDMgrid, zvals,dmvals = get_zdm_grid(
state, new=True, plot=False, method='analytic',
nz=grid_params['nz'], ndm=grid_params['ndm'], dmmax=grid_params['dmmax'],
datdir=resource_filename('zdm', 'GridData'))

# generates zdm grid
grids += initialise_grids(surveys_sep[0], zDMgrid, zvals, dmvals, state, wdist=True, repeaters=False)

if len(surveys_sep[0]) != 0:
zDMgrid, zvals,dmvals = get_zdm_grid(
state, new=True, plot=False, method='analytic',
nz=grid_params['nz'], ndm=grid_params['ndm'], dmmax=grid_params['dmmax'],
datdir=resource_filename('zdm', 'GridData'))

# generates zdm grid
grids += initialise_grids(surveys_sep[1], zDMgrid, zvals, dmvals, state, wdist=True, repeaters=True)
surveys = surveys_sep[0] + surveys_sep[1]

# Minimse the constant accross all surveys
newC, llC = it.minimise_const_only(None, grids, surveys)
for g in grids:
g.state.FRBdemo.lC = newC

# calculate all the likelihoods
llsum = 0
for s, grid in zip(surveys, grids):
llsum += it.get_log_likelihood(grid,s)

except ValueError as e:
print("ValueError, setting likelihood to -inf: " + str(e))
llsum = -np.inf

if np.isnan(llsum):
print("llsum was NaN. Setting to -infinity", param_dict)
llsum = -np.inf

# print("Posterior calc time: " + str(time.time()-t0) + " seconds", flush=True)

return llsum

#==============================================================================

def mcmc_runner(logpf, outfile, params, surveys, grid_params, nwalkers=10, nsteps=100, nthreads=1):
"""
Handles the MCMC running.
Inputs:
logpf = Log posterior function handle
outfile = Name of the output file (excluding .h5 extension)
params = Dictionary of the parameter names, min and max values
surveys = List of surveys being used
grids = List of grids corresponding to the surveys
nwalkers = Number of walkers
nsteps = Number of steps per walker
nthreads = Number of threads to use for parallelised runs
Outputs:
posterior_sample = Final sample
outfile.h5 = HDF5 file containing the sampler
"""

ndim = len(params)
starting_guesses = []

# Produce starting guesses for each parameter
for key,val in params.items():
starting_guesses.append(st.uniform(loc=val['min'], scale=val['max']-val['min']).rvs(size=[nwalkers]))
print(key + " priors: " + str(val['min']) + "," + str(val['max']))

starting_guesses = np.array(starting_guesses).T

backend = emcee.backends.HDFBackend(outfile+'.h5')
backend.reset(nwalkers, ndim)

start = time.time()
with mp.Pool() as pool:
sampler = emcee.EnsembleSampler(nwalkers, ndim, logpf, args=[params, surveys, grid_params], backend=backend, pool=pool)
sampler.run_mcmc(starting_guesses, nsteps, progress=True)
end = time.time()
print("Total time taken: " + str(end - start))

posterior_sample = sampler.get_chain()

return posterior_sample

#==============================================================================
28 changes: 3 additions & 25 deletions zdm/MCMC_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@
import argparse
import os

from zdm import survey
from zdm import cosmology as cos
from zdm import loading
from zdm.MCMC import *

import pickle
import json

#==============================================================================
Expand All @@ -38,8 +35,8 @@ def main():

# Parsing command line
parser = argparse.ArgumentParser()
parser.add_argument('-f', '--files', default=None, nargs='?', type=commasep, help="Survey file names")
parser.add_argument('-r', '--rep_surveys', default=None, nargs='?', type=commasep, help="Surveys to consider repeaters in")
parser.add_argument('-f', '--files', default=None, nargs='+', type=str, help="Survey file names")
parser.add_argument('-r', '--rep_surveys', default=None, nargs='+', type=str, help="Surveys to consider repeaters in")
parser.add_argument('-p','--pfile', default=None , type=str, help="File defining parameter ranges")
parser.add_argument('-o','--opfile', default=None, type=str, help="Output file for the data")
parser.add_argument('-w', '--walkers', default=20, type=int, help="Number of MCMC walkers")
Expand Down Expand Up @@ -68,9 +65,6 @@ def main():
surveys.append(s)
grids.append(g)

if len(surveys) == 0:
raise ValueError("No surveys to use!")

# Make output directory
if args.outdir != "" and not os.path.exists(args.outdir):
os.mkdir(args.outdir)
Expand All @@ -84,21 +78,5 @@ def main():
mcmc_runner(calc_log_posterior, os.path.join(args.outdir, args.opfile), params, surveys, grids, nwalkers=args.walkers, nsteps=args.steps, nthreads=args.nthreads)

#==============================================================================
"""
Function: commasep
Date: 23/08/2022
Purpose:
Turn a string of variables seperated by commas into a list
Imports:
s = String of variables
Exports:
List conversion of s
"""
def commasep(s):
return list(map(str, s.split(',')))

#==============================================================================


main()
95 changes: 95 additions & 0 deletions zdm/MCMC_wrap2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
File: MCMC_wrap.py
Author: Jordan Hoffmann
Date: 28/09/23
Purpose:
Wrapper file to run MCMC analysis for zdm code. Handles command line
parameters and loading of surveys, grids and parameters. Actual MCMC
analysis functions are in MCMC.py.
scripts/run_mcmc.slurm contains an example sbatch script.
"""

import argparse
import os

from zdm import survey
from zdm import cosmology as cos
from zdm import loading
from zdm.MCMC2 import *

import pickle
import json

#==============================================================================

def main():
"""
Handles the setup for MCMC runs. This involves reading / creating the
surveys and grids, reading the parameters and prior ranges and then
beginning the MCMC run.
Inputs:
args = Command line parameters
Outputs:
None
"""

# Parsing command line
parser = argparse.ArgumentParser()
parser.add_argument('-f', '--files', default=None, nargs='+', type=str, help="Survey file names")
parser.add_argument('-r', '--rep_surveys', default=None, nargs='+', type=str, help="Surveys to consider repeaters in")
parser.add_argument('-p','--pfile', default=None , type=str, help="File defining parameter ranges")
parser.add_argument('-o','--opfile', default=None, type=str, help="Output file for the data")
parser.add_argument('-w', '--walkers', default=20, type=int, help="Number of MCMC walkers")
parser.add_argument('-s', '--steps', default=100, type=int, help="Number of MCMC steps")
parser.add_argument('-n', '--nthreads', default=1, type=int, help="Number of threads")
parser.add_argument('--sdir', default=None, type=str, help="Directory containing surveys")
parser.add_argument('--edir', default=None, type=str, help="Directory containing efficiency files")
parser.add_argument('--outdir', default="", type=str, help="Output directory")
args = parser.parse_args()

# Check correct flags are specified
if args.pfile is None or args.opfile is None:
print("-p and -o flags are required")
exit()

# Initialise surveys
surveys = [[], []]
state = parameters.State()

grid_params = {}
grid_params['dmmax'] = 7000.0
grid_params['ndm'] = 1400
grid_params['nz'] = 500
ddm = grid_params['dmmax'] / grid_params['ndm']
dmvals = (np.arange(grid_params['ndm']) + 1) * ddm

if args.files is not None:
for survey_name in args.files:
s = survey.load_survey(survey_name, state, dmvals,
sdir=args.sdir, edir=args.edir)
surveys[0].append(s)

if args.rep_surveys is not None:
for survey_name in args.rep_surveys:
s = survey.load_survey(survey_name, state, dmvals,
sdir=args.sdir, edir=args.edir)
surveys[1].append(s)

# Make output directory
if args.outdir != "" and not os.path.exists(args.outdir):
os.mkdir(args.outdir)

with open(args.pfile) as f:
mcmc_dict = json.load(f)

# Select from dictionary the necessary parameters to be changed
params = {k: mcmc_dict[k] for k in mcmc_dict['mcmc']['parameter_order']}

mcmc_runner(calc_log_posterior, os.path.join(args.outdir, args.opfile), params, surveys, grid_params, nwalkers=args.walkers, nsteps=args.steps, nthreads=args.nthreads)

#==============================================================================

main()
Loading

0 comments on commit 02b7ede

Please sign in to comment.