Skip to content

Commit

Permalink
Merge pull request #97 from icecube/init_model_fix
Browse files Browse the repository at this point in the history
Init model fix
  • Loading branch information
sgriswol authored Oct 30, 2023
2 parents aba968e + 5f8d26f commit 48914b1
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 58 deletions.
66 changes: 25 additions & 41 deletions python/asteria/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,37 +258,12 @@ def get_combined_spectrum(self, t, E, flavor, mixing):
spectrum : np.ndarray
Mixed neutrino spectrum as a 2D array with dim (time, energy)
"""
# TODO: Check that this function still works when p_surv and pc_osc are arrays
# TODO: Simplify after adding neutrino oscillates_to property to SNEWPY
# The cflavor "complementary flavor" is the flavor that the provided argument `flavor` oscillates to/from
if flavor.is_neutrino:
if flavor.is_electron:
coeffs = mixing.prob_ee(t, E), mixing.prob_ex(t, E)
cflavor = Flavor.NU_X
else:
coeffs = mixing.prob_xx(t, E), mixing.prob_xe(t, E)
cflavor = Flavor.NU_E
else:
if flavor.is_electron:
coeffs = mixing.prob_eebar(t, E), mixing.prob_exbar(t, E)
cflavor = Flavor.NU_X_BAR
else:
coeffs = mixing.prob_xxbar(t, E), mixing.prob_xebar(t, E)
cflavor = Flavor.NU_E_BAR

nu_spectrum = np.zeros(shape=(t.size, E.size))
for coeff, _flavor in zip(coeffs, (flavor, cflavor)):
alpha = self.source.alpha(t, _flavor)
meanE = self.source.meanE(t, _flavor).to(u.MeV).value

alpha[alpha < 0] = 0
cut = (alpha >= 0) & (meanE > 0)

flux = self.source.flux(t[cut], _flavor).value.reshape(-1, 1)
nu_spectrum[cut] += coeff * self.source.energy_pdf(t[cut], E, _flavor) * flux

photon_spectrum = self._photon_spectra[flavor].to(u.m ** 2).value.reshape(1, -1)
return nu_spectrum * photon_spectrum
nu_spectrum = self.source.model.get_transformed_spectra(t, E, mixing)[flavor].reshape(t.size, E.size)
cut = (t < self.source.model.time[0]) | (self.source.model.time[-1] < t)
# TODO: Apply a fix for this once a fix has been applied to SNEWPY
nu_spectrum[cut] = 0
photon_spectrum = self._photon_spectra[flavor].reshape(1, E.size)
return nu_spectrum.to(1 / u.MeV / u.s).value * photon_spectrum.to(u.m ** 2).value

def compute_energy_per_vol(self, *, part_size=1000):
"""Compute the energy deposited in a cubic meter of ice by photons
Expand All @@ -315,23 +290,32 @@ def compute_energy_per_vol(self, *, part_size=1000):
self._E_per_V = {}
self._total_E_per_V = np.zeros(self.time.size)

if self.source.special_compat_mode:
part_size = 1 # Done for certain SNEWPY Models until a fix has been applied
# These models can only return spectra for 1 time per call.
# TODO: Fix this once a fix has been applied to SNEWPY

# Perform core calculation on partitions in E to regulate memory usage in vectorized function
# Maximum usage is expected to be ~8MB
for flavor in self.flavors:
# Perform core calculation on partitions in E to regulate memory usage in vectorized function
# Maximum usage is expected to be ~8MB
result = np.zeros(self.time.size)
idx = 0

# Perform integration over spectrum
if part_size < self.time.size:
while idx + part_size < self.time.size:
for idx in np.arange(0, self.time.size, part_size):
spectrum = self.get_combined_spectrum(self.time[idx:idx + part_size], self.energy, flavor,
self._mixing)
result[idx:idx + part_size] = np.trapz(spectrum, self.energy.value, axis=1)
idx += part_size
spectrum = self.get_combined_spectrum(self.time[idx:], self.energy, flavor, self._mixing)
result[idx:] = np.trapz(spectrum, self.energy.value, axis=1)
# Add distance, density and time-binning scaling factors
result *= H2O_in_ice / (4 * np.pi * dist**2) * np.ediff1d(self.time,
to_end=(self.time[-1] - self.time[-2])).value
if not flavor.is_electron:

result *= (
H2O_in_ice * # Target Molecule (H2O) density
np.ediff1d(self.time, to_end=(self.time[-1] - self.time[-2])).value * # Time bin scaling
1 / (4 * np.pi * dist ** 2) # Distance
)

if not flavor.is_electron: # nu_x/nu_x_bar consist of nu_mu(_bar) & nu_tau(_bar), so double them
# TODO: Double check that the models describe single flavor spectrum or multi-flavor spectrum
result *= 2
self._E_per_V.update({flavor: result * (u.MeV / u.m / u.m / u.m)})
self._total_E_per_V += result
Expand Down
78 changes: 61 additions & 17 deletions python/asteria/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,53 @@

import astropy.units as u
import numpy as np
import logging



class Source:

def __init__(self, model, model_params=None):
"""CCSN Soruce Object. Contains methods describing neutrino emission
Parameters
----------
model : str
Name of SNEWPY Model
model_params : dict
SNEWPY Model parameters
"""
self.model = init_model(model, **model_params)
self._interp_lum = {}
self._interp_meanE = {}
self._interp_pinch = {}
self._special_compat_mode = False
t = self.model.time

if all([hasattr(self.model, attr) for attr in ("luminosity", "meanE", "pinch")]):
for flavor in Flavor:
self._interp_lum.update({flavor: PchipInterpolator(t, self.model.luminosity[flavor], extrapolate=False)})
self._interp_meanE.update({flavor: PchipInterpolator(t, self.model.meanE[flavor], extrapolate=False)})
self._interp_pinch.update({flavor: PchipInterpolator(t, self.model.pinch[flavor], extrapolate=False)})
else:
logging.warning(f"Model '{self.model.__class__.__name__}' lacks one or more of the following attributes: "
"'luminosity', 'meanE', 'pinch'. Some ASTERIA methods may not function fully")

if self.model.__class__.__name__ in ('Fornax_2019', 'Fornax_2021'):
logging.warning(f"Model '{self.model.__class__.__name__}' detected. Special compatibility mode enabled.\n"
"Expect a reduction in performance and increase in simulation run times.", )
self._special_compat_mode = True

for flavor in Flavor:
t = self.model.time
self._interp_lum.update({flavor: PchipInterpolator(t, self.model.luminosity[flavor], extrapolate=False)})
self._interp_meanE.update({flavor: PchipInterpolator(t, self.model.meanE[flavor], extrapolate=False)})
self._interp_pinch.update({flavor: PchipInterpolator(t, self.model.pinch[flavor], extrapolate=False)})
@property
def special_compat_mode(self):
"""Indicates whether A special compatibility mode should be used for parsing certain SNEWPY Models
Currently required for (For
Returns
-------
"""
return self._special_compat_mode

def luminosity(self, t, flavor=Flavor.NU_E_BAR):
"""Return interpolated source luminosity at time t for a given flavor.
Expand All @@ -56,7 +88,10 @@ def luminosity(self, t, flavor=Flavor.NU_E_BAR):
luminosity : Astropy.units.quantity.Quantity
Source luminosity (units of power).
"""
return np.nan_to_num(self._interp_lum[flavor](t)) * (u.erg / u.s)
if self._interp_lum:
return np.nan_to_num(self._interp_lum[flavor](t)) * (u.erg / u.s)
else:
raise NotImplementedError('Source is missing `luminosity` interpolator!')

def meanE(self, t, flavor=Flavor.NU_E_BAR):
"""Return interpolated source mean energy at time t for a given flavor.
Expand All @@ -75,7 +110,10 @@ def meanE(self, t, flavor=Flavor.NU_E_BAR):
Source mean energy (units of energy).
"""
# TODO Checks for units/unitless inputs
return np.nan_to_num(self._interp_meanE[flavor](t)) * u.MeV
if self._interp_meanE:
return np.nan_to_num(self._interp_meanE[flavor](t)) * u.MeV
else:
raise NotImplementedError('Source is missing `meanE` interpolator!')

def alpha(self, t, flavor=Flavor.NU_E_BAR):
"""Return source pinching paramter alpha at time t for a given flavor.
Expand Down Expand Up @@ -110,19 +148,22 @@ def flux(self, t, flavor=Flavor.NU_E_BAR):
flux :
Source number flux (unit-less, count of neutrinos).
"""
L = self.luminosity(t, flavor).to(u.MeV / u.s).value
meanE = self.meanE(t, flavor).value
if self._interp_meanE and self._interp_lum:
L = self.luminosity(t, flavor).to(u.MeV / u.s).value
meanE = self.meanE(t, flavor).value

if isinstance(t, np.ndarray):
_flux = np.divide(L, meanE, where=(meanE > 0), out=np.zeros(L.size))
else:
# TODO: Fix case where t is list, or non astropy quantity. This is a front-end function for some use cases
if meanE > 0.:
_flux = L / meanE
if isinstance(t, np.ndarray):
_flux = np.divide(L, meanE, where=(meanE > 0), out=np.zeros(L.size))
else:
_flux = 0
# TODO: Fix case where t is list, or non astropy quantity. This is a front-end function for some use cases
if meanE > 0.:
_flux = L / meanE
else:
_flux = 0

return _flux / u.s
return _flux / u.s
else:
raise NotImplementedError('Source is missing `meanE` and.or `luminosity` interpolator!')

@staticmethod
def _energy_pdf(a, Ea, E):
Expand All @@ -134,6 +175,9 @@ def _energy_cdf(a, Ea, E):
return gdtr(1., a + 1., (a + 1.) * (E / Ea))

def energy_pdf(self, t, E, flavor=Flavor.NU_E_BAR, *, limit_size=True):
if not self._interp_lum or not self._interp_pinch or not self._interp_meanE:
raise NotImplementedError('Source is missing `meanE`, `pinch` and/or `luminosity` interpolator!\n'
'Please use SNEWPY Model method `get_initial_spectrum`.')
_E = E.to(u.MeV).value
if isinstance(E, np.ndarray):
if _E[0] == 0:
Expand Down

0 comments on commit 48914b1

Please sign in to comment.