Skip to content

Commit

Permalink
Merge pull request #47 from FRBs/spline_debug
Browse files Browse the repository at this point in the history
Fixed spline interp problem
  • Loading branch information
profxj authored Jan 10, 2024
2 parents 3e891f1 + 86b8fee commit 5d343cb
Show file tree
Hide file tree
Showing 3 changed files with 324 additions and 9 deletions.
42 changes: 35 additions & 7 deletions zdm/energetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,43 @@
SplineMin = -6
SplineMax = 6
NSpline = 1000
SplineLog = True

############## this section defines different luminosity functions ##########

def init_igamma_splines(gammas, reinit=False):
global SplineMin,SplineMax,NSpline
def init_igamma_splines(gammas, reinit=False,k=3):
"""
gammas [list of floats]: list of values of gamma at which splines
must be created
reinit [bool]: if True, will re-initialise even if a spline for
that gamma has already been created
k [int]: degree of spline to use. 3 by default (cubic splines).
Formal range: integers 1 <= k <= 5. Do NOT use 2 or 4.
If SplineLog is set, interpolations are performed in log-space,
i.e. the results is a spline interpolation of the log10 of the
answer in terms of the log10 of the input
"""
global SplineMin,SplineMax,NSpline,SplineLog
for gamma in gammas:
if gamma not in igamma_splines.keys() or reinit:
print(f"Initializing igamma_spline for gamma={gamma}")
avals = 10**np.linspace(SplineMin, SplineMax, NSpline)
lavals = np.linspace(SplineMin, SplineMax, NSpline)
avals = 10**lavals
numer = np.array([float(mpmath.gammainc(
gamma, a=iEE)) for iEE in avals])
# iGamma
igamma_splines[gamma] = interpolate.splrep(avals, numer,k=3)

if SplineLog:
# check for literal zeros, set them to small values
zero = np.where(numer == 0.)[0]
ismall = zero[0]-1
smallest = numer[ismall]
numer[zero] = smallest
lnumer = np.log10(numer)
igamma_splines[gamma] = interpolate.splrep(lavals, lnumer,k=k)
else:
igamma_splines[gamma] = interpolate.splrep(avals, numer,k=k)


def init_igamma_linear(gammas:list, reinit:bool=False,
log:bool=False):
""" Setup the linear interpolator for gamma
Expand Down Expand Up @@ -203,6 +226,8 @@ def vector_cum_gamma_spline(Eth:np.ndarray, *params):
Returns:
np.ndarray: [description]
"""
global SplineLog

params=np.array(params)
Emin=params[0]
Emax=params[1]
Expand All @@ -213,7 +238,10 @@ def vector_cum_gamma_spline(Eth:np.ndarray, *params):
Eth_Emax = Eth/Emax
if gamma not in igamma_splines.keys():
init_igamma_splines([gamma])
numer = interpolate.splev(Eth_Emax, igamma_splines[gamma])
if SplineLog:
numer = 10**interpolate.splev(np.log10(Eth_Emax), igamma_splines[gamma])
else:
numer = interpolate.splev(Eth_Emax, igamma_splines[gamma])
result=numer/norm

# Low end
Expand Down
4 changes: 2 additions & 2 deletions zdm/survey.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,14 +865,14 @@ def calc_relative_sensitivity(DM_frb,DM,w,fbar,t_res,nu_res,Nchan=336,max_idt=No

# Set sensitivity to 0 above the maximum searched DM
if max_dm != None:
sensitivity[DM > max_dm] = 1e-2 # Effectively 0 but not small enough to break it...
sensitivity[DM > max_dm] = 1e-5 # Effectively 0 but not small enough to break it...
if max_idt != None:
f_low = fbar - (Nchan/2. - 1)*nu_res
f_high = fbar + (Nchan/2. - 1)*nu_res
max_dt = t_res * max_idt # FREDDA searches up to 4096 time bins
max_dm2 = max_dt / (k_DM * ((f_low/1e3)**(-2) - (f_high/1e3)**(-2)))

sensitivity[DM > max_dm2] = 1e-2 # Effectively 0 but not small enough to break it...
sensitivity[DM > max_dm2] = 1e-5 # Effectively 0 but not small enough to break it...

# If model not CHIME, Quadrature or Sammons assume it is a filename
else:
Expand Down
287 changes: 287 additions & 0 deletions zdm/tests/test_splines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
"""
This script produces plots to test spline accuracy,
and test the time taken to create and evaluate splines.
"""

from zdm import energetics
import mpmath
import numpy as np
from scipy import interpolate

def main():


# this compares the accuracy of the default spline method, and
# complains if anything if wrong by more than 0.01 %
max_rdiff,lowest = test_default_spline()
assert(max_rdiff < 1e-4)
assert(lowest >= 0.)

# this creates a bunch of plots to test the accuracy of different spline methods
# off by default
#test_spline_accuracy()

# this times how long it takes to produce splines
# off by default
#time_splines(Nreps=1)

def test_default_spline(gamma=-1.5837, amin=1e-6, amax=1e6,
Ntest=1000, plot=False):
"""
Function to test the accuracy of the spline interplation
It checks the accuracy vs direct calculation
It does *not* check the accuracy of the underlying Python routine
Input:
gamma [float]: value of the index gamma to test at
amin [float]: minimum value relative to Emax at which to calculate
amax [float]: maximum value relative to Emax at which to calculate
Ntest [int]: number of random values to generate
plot [bool]: turn on to show a quick plot of results
Returns:
phys_worst: worst relative error in the "physical" range where
the true result is > 1e-200
min: absolute minimum value, to check this never goes negative
"""

# completely arbitrary choices of Emin and Emax
Emin = 1e30
Emax = 1e42

# the range over which to generate test values of Eth
lEthmin = np.log10(Emax) + np.log10(amin)
lEthmax = np.log10(Emax) + np.log10(amax)


# generate random values in the range of the splines
lEth = np.random.rand(Ntest) * \
(lEthmax - lEthmin) + lEthmin
Eth = 10**lEth
Eth = np.sort(Eth)

# generate a default spline
energetics.init_igamma_splines([gamma])
# evaluate that spline
result = energetics.vector_cum_gamma_spline(Eth,Emin,Emax,gamma)
# evaluate the true values
truth = energetics.vector_cum_gamma(Eth,Emin,Emax,gamma)

rdiff = np.abs(result - truth)/truth

# gets the greatest error in the "physical" range (> 1e-200)
phys = np.where(truth > 1e-200)[0]
phys_rdiff = rdiff[phys]

phys_worst = np.max(phys_rdiff)

# gets the minimum value over the entire range
lowest = np.min(result)

if plot:
import matplotlib.pyplot as plt
plt.figure()
plt.plot(Eth,truth,linestyle='',marker='x')

plt.plot(Eth,result,linestyle='',marker='o')
plt.xscale('log')
plt.yscale('log')
plt.show()

return phys_worst, lowest


def test_spline_accuracy(gamma = -1.5837, Ntest = 100):
"""
Function to test the accuracy of the spline interplation
It explores different methods of spline interpolation
Ntest [int]: number of random values to generate
gamma [float]: value of the index gamma to test at
Output: This produces three plots, being
spline_test.pdf: compares splines to truth
spline_errors.pdf: plots absolute value of errors
spline_rel_errors.pdf: plots absolute value of relative errors
"""

# generate random values in the range of the splines
lnewavals = np.random.rand(Ntest) * \
(energetics.SplineMax - energetics.SplineMin) + energetics.SplineMin
lnewavals = np.sort(lnewavals)
newavals = 10**lnewavals


# generate the true values at these avalues via direct calculation
truth = np.zeros(Ntest)
for i,av in enumerate(newavals):
truth[i] = mpmath.gammainc(gamma, a=av)

# sets up for plotting results
from matplotlib import pyplot as plt
plt.figure()
ax1=plt.gca()
plt.plot(newavals,truth,label='truth')
plt.xlabel('a')
plt.ylabel('cumulative gamma function')
plt.xscale('log')
plt.yscale('log')

# sets up for plotting differences between truth and interpolation
plt.figure()
ax2=plt.gca()
plt.xlabel('a')
plt.ylabel('|interpolation - truth|')
plt.xscale('log')
plt.yscale('log')

# sets up for plotting relative differences between truth and interpolation
plt.figure()
ax3=plt.gca()
plt.xlabel('a')
plt.ylabel('|interpolation - truth|/truth')
plt.xscale('log')
plt.yscale('log')

# now use different spline methods to evaluate
# standard method: cubic, linear

energetics.SplineLog=False
energetics.init_igamma_splines([gamma],reinit=True,k=3)
result = interpolate.splev(newavals, energetics.igamma_splines[gamma])
ax1.plot(newavals,result,label='cubic spline, linear')
diff = np.abs(result - truth)
ax2.plot(newavals,diff,label='cubic spline, linear')
rdiff = diff/truth
ax3.plot(newavals,rdiff,label='cubic spline, linear')

# now use different spline methods to evaluate
# standard method: cubic, linear
energetics.init_igamma_splines([gamma],reinit=True,k=1)
result = interpolate.splev(newavals, energetics.igamma_splines[gamma])
ax1.plot(newavals,result,label='linear spline, linear')
diff = np.abs(result - truth)
ax2.plot(newavals,diff,label='linear spline, linear')
rdiff = diff/truth
ax3.plot(newavals,rdiff,label='linear spline, linear')


energetics.SplineLog=True
energetics.init_igamma_splines([gamma],reinit=True,k=3)
result = 10**interpolate.splev(lnewavals, energetics.igamma_splines[gamma])
ax1.plot(newavals,result,label='cubic spline, log')
diff = np.abs(result - truth)
ax2.plot(newavals,diff,label='cubic spline, log')
rdiff = diff/truth
ax3.plot(newavals,rdiff,label='cubic spline, log')


# now use different spline methods to evaluate
# standard method: cubic, linear
energetics.init_igamma_splines([gamma],reinit=True,k=1)
result = 10**interpolate.splev(lnewavals, energetics.igamma_splines[gamma])
ax1.plot(newavals,result,label='linear spline, log')
diff = np.abs(result - truth)
ax2.plot(newavals,diff,label='linear spline, log')
rdiff = diff/truth
ax3.plot(newavals,rdiff,label='linear spline, log')

plt.sca(ax1)
plt.legend()
plt.tight_layout()
plt.savefig('spline_test.pdf')
plt.close()

plt.sca(ax2)
plt.legend()
plt.tight_layout()
plt.savefig('spline_errors.pdf')

plt.sca(ax3)
plt.legend()
plt.ylim(1e-14,1)
plt.tight_layout()
plt.savefig('spline_rel_errors.pdf')
plt.close()


def time_splines(gamma = -1.5837, Ntimetest = 100000, Nreps=100):
"""
Function to time different methods of spline interpolation.
It explores different methods of spline interpolation
gamma [float]: value of the index gamma to test at
Ntimetest [int]: number of random values to generate to test on
Nreps [int]: number of repetitions for timing
"""
import time

# begins with very large array of values to evaluate on
lnewavals = np.random.rand(Ntimetest) * \
(energetics.SplineMax - energetics.SplineMin) + energetics.SplineMin
newavals = 10**lnewavals

# now use different spline methods to evaluate
# standard method: cubic, linear

energetics.SplineLog=False
t0=time.time()
for i in np.arange(Nreps):
energetics.init_igamma_splines([gamma],reinit=True,k=3)
t1=time.time()
for i in np.arange(Nreps):
result = interpolate.splev(newavals, energetics.igamma_splines[gamma])
t2=time.time()
dt1_ci = t1-t0
dt2_ci = t2-t1

# now use different spline methods to evaluate
# standard method: cubic, linear
t0=time.time()
for i in np.arange(Nreps):
energetics.init_igamma_splines([gamma],reinit=True,k=1)
t1=time.time()
for i in np.arange(Nreps):
result = interpolate.splev(newavals, energetics.igamma_splines[gamma])
t2=time.time()
dt1_li = t1-t0
dt2_li = t2-t1

energetics.SplineLog=True
t0=time.time()
for i in np.arange(Nreps):
energetics.init_igamma_splines([gamma],reinit=True,k=3)
t1=time.time()
for i in np.arange(Nreps):
result = 10**interpolate.splev(lnewavals, energetics.igamma_splines[gamma])
t2=time.time()
dt1_co = t1-t0
dt2_co = t2-t1

# now use different spline methods to evaluate
# standard method: cubic, linear
t0=time.time()
for i in np.arange(Nreps):
energetics.init_igamma_splines([gamma],reinit=True,k=1)
t1=time.time()
for i in np.arange(Nreps):
result = 10**interpolate.splev(lnewavals, energetics.igamma_splines[gamma])
t2=time.time()
dt1_lo = t1-t0
dt2_lo = t2-t1

print("Performing ",Nreps," spline initialisations took...")
print("Cubic spline in linear space: ",dt1_ci)
print("Linear spline in linear space: ",dt1_li)
print("Cubic spline in log space: ",dt1_co)
print("Linear spline in log space: ",dt1_lo)

print("Performing ",Nreps," x ",Ntimetest," spline evaluations took...")
print("Cubic spline in linear space: ",dt2_ci)
print("Linear spline in linear space: ",dt2_li)
print("Cubic spline in log space: ",dt2_co)
print("Linear spline in log space: ",dt2_lo)


main()

0 comments on commit 5d343cb

Please sign in to comment.