Skip to content

Commit

Permalink
Protect against failed fits
Browse files (browse the repository at this point in the history)
  • Loading branch information
brettviren committed Feb 24, 2024
1 parent b041971 commit aa20f6e
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 19 deletions.
61 changes: 48 additions & 13 deletions wirecell/test/__main__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,6 +22,7 @@ def cli(ctx):
'''
pass


@cli.command("plot")
@click.option("-n", "--name", default="noise",
help="The test name")
Expand All @@ -37,7 +38,6 @@ def plot(ctx, name, datafile, output):
fp = ario.load(datafile)
with plottools.pages(output) as out:
mod.plot(fp, out)



def ssss_args(func):
Expand All @@ -53,8 +53,11 @@ def ssss_args(func):
@functools.wraps(func)
def wrapper(*args, **kwds):

kwds["splat"] = ssss.load_frame(kwds.pop("splat"))
kwds["signal"] = ssss.load_frame(kwds.pop("signal"))
kwds["splat_filename"] = kwds.pop("splat")
kwds["signal_filename"] = kwds.pop("signal")

kwds["splat"] = ssss.load_frame(kwds["splat_filename"])
kwds["signal"] = ssss.load_frame(kwds["signal_filename"])

channel_ranges = kwds.pop("channel_ranges")
if channel_ranges:
Expand All @@ -74,6 +77,8 @@ def plot_ssss(channel_ranges, nsigma, nbins, splat, signal, output,
Perform the simple splat / sim+signal process comparison test and make plots.
'''

nminsig = 3 # sanity check

with pages(output) as out:

ssss.plot_frames(splat, signal, channel_ranges, title)
Expand All @@ -100,6 +105,13 @@ def plot_ssss(channel_ranges, nsigma, nbins, splat, signal, output,

spl_qch = numpy.sum(spl.activity[bbox], axis=1)
sig_qch = numpy.sum(sig.activity[bbox], axis=1)

nspl = len(spl_qch)
nsig = len(sig_qch)
if nspl != nsig or nsig < nminsig:
log.error(f'error: bad signals: {nspl=} {nsig=} {pln=} {ch=}')
raise ValueError(f'bad signals: {nspl=} {nsig=}')

byplane.append((spl_qch, sig_qch))


Expand Down Expand Up @@ -132,7 +144,14 @@ def ssss_metrics(channel_ranges, nsigma, nbins, splat, signal, output, params, *
spl_qch = numpy.sum(spl.activity[bbox], axis=1)
sig_qch = numpy.sum(sig.activity[bbox], axis=1)

m = ssss.calc_metrics(spl_qch, sig_qch, nbins)
try:
m = ssss.calc_metrics(spl_qch, sig_qch, nbins)
except Exception as err:
splat_filename = kwds['splat_filename']
signal_filename = kwds['signal_filename']
log.error(f'error: ({err}) failed to calculate metrics for {pln=} {ch=} {splat_filename=} {signal_filename=}')
m = ssss.Metrics()

metrics.append(dataclasses.asdict(m))

if params:
Expand All @@ -149,8 +168,10 @@ def ssss_metrics(channel_ranges, nsigma, nbins, splat, signal, output, params, *
help="PDF file in which to plot metrics")
@click.option("--coordinate-plane", default=None, type=int,
help="Use given plane number as global coordinates plane, default uses per-plane coordinates")
@click.option("-t","--title", default="",
help="The title string")
@click.argument("files",nargs=-1)
def plot_metrics(output, coordinate_plane, files):
def plot_metrics(output, coordinate_plane, title, files):
'''Plot per-plane metrics from files.
Files are as produced by ssss-metrics and must include a "params" key.
Expand All @@ -176,9 +197,15 @@ def add(k,v):

pmet = met[plane]
add('ineff', pmet['ineff'])
add('bias', pmet['fit']['avg'])
hi = pmet['fit']['hi']
lo = pmet['fit']['lo']
fit = pmet['fit']
if fit is None:
add('bias', 1) # fixme: best way to show failure?
add('reso', 1)
continue

add('bias', fit['avg'])
hi = fit['hi']
lo = fit['lo']
add('reso', 0.5*(hi+lo) )
continue;

Expand All @@ -193,9 +220,15 @@ def add(k,v):
add('ty', par['theta_y_wps'][plane])
add('txz', par['theta_xz_wps'][plane])
add('ineff', pmet['ineff'])
add('bias', pmet['fit']['avg'])
hi = pmet['fit']['hi']
lo = pmet['fit']['lo']
fit = pmet['fit']
if fit is None:
add('bias', 1) # fixme: best way to show failure?
add('reso', 1)
continue

add('bias', fit['avg'])
hi = fit['hi']
lo = fit['lo']
add('reso', 0.5*(hi+lo) )


Expand All @@ -207,11 +240,13 @@ def add(k,v):
pcolors = ('#58D453', '#7D99D1', '#D45853')

fig, axes = plt.subplots(nrows=3, ncols=1, sharex=True)
if title:
title = ' - ' + title
if coordinate_plane is None:
fig.suptitle("Per-plane angles")
fig.suptitle("Per-plane angles" + title)
else:
letter = "UVW"[coordinate_plane]
fig.suptitle(f'Global angles ({letter}-plane)')
fig.suptitle(f'Global angles ({letter}-plane)' + title)

todeg = 180/numpy.pi
# xlabs = [f'{txz}/{ty}' for txz,ty in zip(
Expand Down
23 changes: 19 additions & 4 deletions wirecell/test/ssss.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,6 +15,9 @@
baseline_noise,
gauss as gauss_func
)
import logging
log = logging.getLogger("wirecell.test")


def relbias(a,b):
'''
Expand Down Expand Up @@ -197,17 +200,17 @@ def plot_plane(spl_act, sig_act, nsigma=3.0, title=""):
class Metrics:
'''Metrics about a signal vs splat'''

neor: int
neor: int = 0
''' Number of channels with activity in either the signal or splat (or both)
and over which the rest are calculated. This can be less than the number of
channels in the original "activity" arrays if any given channel has zero
activity in both "signal" and "splat". '''

ineff: float
ineff: float = -1
''' The relative inefficiency. This is the fraction of channels with splat
but with zero signal. '''

fit: BaselineNoise
fit: BaselineNoise | None = None
'''
Gaussian fit to relative difference. .mu is bias and .sigma is resolution.
'''
Expand All @@ -220,6 +223,12 @@ def calc_metrics(spl_qch, sig_qch, nbins=50):
- nbins :: the number of bins over which to fit the relative difference.
'''

nspl = len(spl_qch)
nsig = len(sig_qch)

if nspl != nsig:
raise ValueError(f'length mismatch {nspl=} != {nsig=}')

# either-or, exclude channels where both are zero
eor = numpy.logical_or (spl_qch > 0, sig_qch > 0)
# both are nonzero
Expand Down Expand Up @@ -247,7 +256,13 @@ def plot_metrics(splat_signal_activity_pairs, nbins=50, title="", letters="UVW")
fig, axes = plt.subplots(nrows=2, ncols=3, sharey="row")
for pln, (spl_qch, sig_qch) in enumerate(splat_signal_activity_pairs):

m = calc_metrics(spl_qch, sig_qch, nbins)
try:
m = calc_metrics(spl_qch, sig_qch, nbins)
except:
log.error(f'error: failed to get metric for {pln=} {spl_qch.size=} {sig_qch.size=} {nbins=} {title=}')
log.debug(f'skipped splat: {spl_qch=}')
log.debug(f'skipped signal: {sig_qch=}')
continue
counts, edges = m.fit.hist
model = gauss_func(edges[:-1], m.fit.A, m.fit.mu, m.fit.sigma)

Expand Down
20 changes: 18 additions & 2 deletions wirecell/util/peaks.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,6 +16,9 @@
from wirecell.util.codec import dataclass_dictify
from wirecell.util.bbox import union as union_bbox

import logging
log = logging.getLogger("wirecell.util")

sqrt2pi = sqrt(2*pi)

def gauss(x, A, mu, sigma, *p):
Expand Down Expand Up @@ -45,6 +48,11 @@ class BaselineNoise:
Width (fit standard deviation)
'''

N : int
'''
Number of samples
'''

C : float
'''
Normalization (sum)
Expand Down Expand Up @@ -93,11 +101,18 @@ def baseline_noise(array, bins=200, vrange=100):
defines an extent about the MEDIAN VALUE. If it is a tuple it gives this
extent explicitly or if scalar the extent is symmetric, ie median+/-vrange.
This will raise exceptions:
- ZeroDivisionError when the signal in the vrange is zero.
- RuntimeError when the fit fails.
'''
nsig = len(array)
lo, med, hi = numpy.quantile(array, [0.5-0.34,0.5,0.5+0.34])

if not isinstance(vrange, tuple):
vrange=(-vrange, vrange)
vrange=(med-vrange, med+vrange)
vrange=(med+vrange[0], med+vrange[1])

hist = numpy.histogram(array, bins=bins, range=vrange)
Expand All @@ -113,11 +128,12 @@ def baseline_noise(array, bins=200, vrange=100):
(A,mu,sig),cov = curve_fit(gauss, edges[:-1], counts, p0=p0)
except RuntimeError:
cov = None

return BaselineNoise(A=A, mu=mu, sigma=sig,
N=nsig,
C=C, avg=avg, rms=rms,
med=med, lo=lo, hi=hi,
cov=cov, hist=hist)


@dataclasses.dataclass
@dataclass_dictify
Expand Down

0 comments on commit aa20f6e

Please sign in to comment.