From 1aa5b25b09f6eb71a63d677fbd7fc2d1034b62d5 Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Wed, 31 Jan 2024 08:11:37 -0500 Subject: [PATCH] Support for resampling, plots and various improvemetns/bug fixes --- wirecell/plot/__main__.py | 82 +++++++++++++- wirecell/sigproc/response/persist.py | 9 +- wirecell/util/__main__.py | 57 ++++++++++ wirecell/util/cli.py | 7 +- wirecell/util/fileio.py | 8 +- wirecell/util/frames.py | 161 +++++++++++++++++++++++++++ wirecell/util/jsio.py | 22 +++- wirecell/util/lmn.py | 10 +- wirecell/util/plottools.py | 2 + 9 files changed, 342 insertions(+), 16 deletions(-) create mode 100644 wirecell/util/frames.py diff --git a/wirecell/plot/__main__.py b/wirecell/plot/__main__.py index d012800..02ebc6c 100644 --- a/wirecell/plot/__main__.py +++ b/wirecell/plot/__main__.py @@ -7,7 +7,9 @@ from wirecell.util import ario, plottools from wirecell.util.cli import log, context, jsonnet_loader, frame_input, image_output from wirecell.util import jsio - +from wirecell.util.functions import unitify, unitify_parse +from wirecell import units +from pathlib import Path import numpy import matplotlib.pyplot as plt @@ -96,7 +98,7 @@ def ntier_frames(cmap, output, files): @cli.command("frame") @click.option("-n", "--name", default="wave", type=click.Choice(["wave","spectra"]), - help="The frame plot type name [default=waf]") + help="The frame plot type name [default=wave]") @click.option("-t", "--tag", default="orig", help="The frame tag") @click.option("-u", "--unit", default="ADC", @@ -334,6 +336,82 @@ def plot_slice(ax, slc): out.savefig() +@cli.command("channels") +@click.option("-c", "--channel", multiple=True, default=(), required=True, + help="Specify channels, eg '1,2:4,5' are 1-5 but not 4") +@click.option("-t", "--trange", default=None, type=str, + help="limit time range, eg '0,3*us'") +@click.option("-f", "--frange", default=None, type=str, + help="limit frequency range, eg '0,100*kHz'") +@image_output +@click.argument("frame_files", nargs=-1) +def channels(output, channel, trange, frange, frame_files, **kwds): + ''' + Plot channels from multiple frame files. + + Frames need not have same sample period (tick). + + If --single put all on one plot, else per-channel plots are made + ''' + + from wirecell.util.frames import load as load_frames + + if trange: + trange = unitify_parse(trange) + if frange: + frange = unitify_parse(frange) + + channels = list() + for chan in channel: + for one in chan.split(","): + if ":" in one: + f,l = one.split(":") + channels += range(int(f), int(l)) + else: + channels.append(int(one)) + + # fixme: move this mess out of here + + frames = {ff: list(load_frames(ff)) for ff in frame_files} + + with output as out: + + for chan in channels: + + fig,axes = plt.subplots(nrows=1, ncols=2) + fig.suptitle(f'channel {chan}') + + for fname, frs in frames.items(): + stem = Path(fname).stem + for fr in frs: + wave = fr.waves(chan) + axes[0].set_title("waveforms") + axes[0].plot(fr.times/units.us, wave, drawstyle='steps') + if trange: + axes[0].set_xlim(trange[0]/units.us, trange[1]/units.us) + axes[0].set_xlabel("time [us]") + + axes[1].set_title("spectra") + axes[1].plot(fr.freqs_MHz, numpy.abs(numpy.fft.fft(wave)), + label=f'{fr.nticks}x{fr.period/units.ns:.0f}ns\n{stem}') + axes[1].set_yscale('log') + if frange: + axes[1].set_xlim(frange[0]/units.MHz, frange[1]/units.MHz) + else: + axes[1].set_xlim(0, fr.freqs_MHz[fr.nticks//2]) + axes[1].set_xlabel("frequency [MHz]") + axes[1].legend() + print(fr.nticks, fr.period/units.ns, fr.duration/units.us) + + + if not out.single: + out.savefig() + plt.clf() + if out.single: + out.savefig() + + + def main(): cli(obj=dict()) diff --git a/wirecell/sigproc/response/persist.py b/wirecell/sigproc/response/persist.py index e929e60..41a9ef0 100644 --- a/wirecell/sigproc/response/persist.py +++ b/wirecell/sigproc/response/persist.py @@ -244,10 +244,13 @@ def load_detector(name): ''' Load response(s) given a canonical detector name. ''' + if ".json" in name: + raise ValueError(f'detector name looks like a file name: {name}') - fields = detectors.load(name, "fields") - if not fields: - raise IOError(f'failed to load responses for detector "{name}"') + try: + fields = detectors.load(name, "fields") + except KeyError: + raise IOError(f'failed to load fields for detector "{name}"') if isinstance(fields, list): return [pod2schema(f) for f in fields] diff --git a/wirecell/util/__main__.py b/wirecell/util/__main__.py index 660694f..06d1f2f 100644 --- a/wirecell/util/__main__.py +++ b/wirecell/util/__main__.py @@ -1249,6 +1249,63 @@ def cmd_detectors(path): +@cli.command("resample") +@click.option("-t", "--tick", default='500*ns', + help="Resample the frame to have this sample period with units, eg '500*ns'") +@click.option("-o","--output", type=str, required=True, + help="Output filename") +@click.argument("framefile") +def resample(tick, output, framefile): + ''' + Resample a frame file + ''' + from . import ario, lmn + + Tr = unitify(tick) + print(f'resample to {Tr/units.ns}ns to {output}') + + + fp = ario.load(framefile) + f_names = [k for k in fp.keys() if k.startswith("frame_")] + c_names = [k for k in fp.keys() if k.startswith("channels_")] + t_names = [k for k in fp.keys() if k.startswith("tickinfo_")] + + out = dict() + + for fnum, frame_name in enumerate(f_names): + _, suffix = frame_name.split('_',1) + ti = fp[f'tickinfo_{suffix}'] + Ts = ti[1] + + if Tr == Ts: + print(f'frame {fnum} "{frame_name}" already sampled at {Tr}') + continue + + frame = fp[frame_name] + Ns = frame.shape[1] + ss = lmn.Sampling(T=Ts, N=Ns) + + Nr = round(Ns * Tr / Ts) + sr = lmn.Sampling(T=Tr, N=Nr) + + print(f'{fnum} {frame_name} {ss=} -> {sr=}') + + resampled = numpy.zeros((frame.shape[0], Nr), dtype=frame.dtype) + for irow, row in enumerate(frame): + sig = lmn.Signal(ss, wave=row) + resig = lmn.interpolate(sig, Tr) + wave = resig.wave + # if Nr != wave.size: + # print(f'resizing to min({Nr=},{wave.size=})') + Nend = min(Nr, wave.size) + resampled[irow,:Nend] = wave[:Nend] + + out[f'frame_{suffix}'] = resampled + out[f'tickinfo_{suffix}'] = numpy.array([ti[0], Tr, ti[2]]) + out[f'channels_{suffix}'] = fp[f'channels_{suffix}'] + + numpy.savez_compressed(output, **out) + @cli.command("resolve") @click.option("-p","--path", default=(), multiple=True, help="Add a search path") diff --git a/wirecell/util/cli.py b/wirecell/util/cli.py index 27a30e5..3ed835e 100644 --- a/wirecell/util/cli.py +++ b/wirecell/util/cli.py @@ -263,9 +263,12 @@ def wrapper(*args, **kwds): single = kwds.pop("single", None) if single: - kwds["output"] = plottools.NameSingleton(output, format=fmt) + out = plottools.NameSingleton(output, format=fmt) + out.single = True else: - kwds["output"] = plottools.pages(output, format=fmt) + out = plottools.pages(output, format=fmt) + out.single = False + kwds["output"] = out kwds["cmap"] = colormaps[kwds["cmap"]] diff --git a/wirecell/util/fileio.py b/wirecell/util/fileio.py index 43f4901..749c594 100644 --- a/wirecell/util/fileio.py +++ b/wirecell/util/fileio.py @@ -8,11 +8,13 @@ from pathlib import Path # fixme: more generic path functions are in jsio which should move here -def wirecell_path(): +def wirecell_path(env=os.environ): ''' - Return list of paths from WIRECELL_PATH. + Return list of paths from WIRECELL_PATH environment variable. + + - env :: A specific environment dictionary else os.environ is used. ''' - return tuple(os.environ.get("WIRECELL_PATH","").split(":")) + return tuple(env.get("WIRECELL_PATH","").split(":")) def source_type(name): ''' diff --git a/wirecell/util/frames.py b/wirecell/util/frames.py new file mode 100644 index 0000000..7e85f24 --- /dev/null +++ b/wirecell/util/frames.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python +''' +Some helpers for frame like objects +''' + +from . import ario +import numpy +import dataclasses +from wirecell import units + +@dataclasses.dataclass +class Frame: + + samples : numpy.ndarray | None = None + ''' + Frame samples as 2D array shape (nchans, nticks). + ''' + + channels : numpy.ndarray | None = None + ''' + Array of channel identity numbers. + ''' + + period: float | None = None + ''' + The time-domain sampling period (aka "tick"). + ''' + + tref: float = 0.0 + ''' + The reference time. + ''' + + tbin: int = 0 + ''' + The time bin represented by the first column of samples array. + ''' + + name: str = "" + ''' + Some human identifier. + ''' + + @property + def nchans(self): + ''' + Number of channels + ''' + return self.samples.shape[0] + + @property + def chan_bounds(self): + ''' + Pair of min/max channel number + ''' + return (numpy.min(self.channels), numpy.max(self.channels)) + + @property + def nticks(self): + ''' + Number of samples in time + ''' + return self.samples.shape[1] + + @property + def tstart(self): + ''' + The time of the first sample + ''' + return self.tref + self.period*self.tbin + + @property + def duration(self): + ''' + The time spanned by the samples + ''' + return self.nticks*self.period + + @property + def absolute_times(self): + ''' + An array of absolute times of samples + ''' + t0 = self.tstart + tf = self.duration + return numpy.linspace(t0, tf, self.nticks, endpoint=False) + + @property + def times(self): + ''' + An array of times of samples relative to first + ''' + t0 = 0 + tf = self.duration + return numpy.linspace(t0, tf, self.nticks, endpoint=False) + + @property + def Fmax_hz(self): + ''' + Max sampling frequency in hz + ''' + T_s = self.period/units.s + return 1/T_s + + @property + def freqs_MHz(self): + ''' + An array of frequencies in MHz of Fourier-domain samples + ''' + return numpy.linspace(0, self.Fmax_hz/1e6, self.nticks, endpoint=False) + + @property + def chids2row(self): + ''' + Return a channel ID to row index + ''' + return {c:i for i,c in enumerate(self.channels)} + + def waves(self, chans): + ''' + Return waveform rows for channel ID or sequence of IDs in chans. + ''' + scalar = False + if isinstance(chans, int): + chans = [chans] + scalar = True + + lu = self.chids2row + + nchans = len(chans) + shape = (nchans, self.nticks) + ret = numpy.zeros(shape, dtype=self.samples.dtype) + for ind,ch in enumerate(chans): + row = lu[ch] + ret[ind,:] = self.samples[row] + if scalar: + return ret[0] + return ret + + def __str__(self): + return f'({self.nchans}ch,{self.nticks}x{self.period/units.ns:.0f}ns) @ {self.tstart/units.us:.0f}us' + +def load(fp): + ''' + Yield frame objects in fp. + + fp is file name as string or pathlib.Path or ario/numpy.load() like. + ''' + if isinstance(fp, str): + fp = ario.load(fp) + + frame_names = [key for key in fp if key.startswith("frame_")] + for frame_name in frame_names: + _,tag,num = frame_name.split("_") + + ti = fp[f'tickinfo_{tag}_{num}'] + ch = fp[f'channels_{tag}_{num}'] + + yield Frame(samples=fp[frame_name], channels=ch, + period=ti[1], tref=ti[0], tbin=int(ti[2]), + name=frame_name) diff --git a/wirecell/util/jsio.py b/wirecell/util/jsio.py index efc4345..d41a2fa 100644 --- a/wirecell/util/jsio.py +++ b/wirecell/util/jsio.py @@ -8,6 +8,7 @@ import json import gzip from pathlib import Path +from wirecell.util.fileio import wirecell_path def jsonnet_module(): try: @@ -37,13 +38,25 @@ def clean_paths(paths, add_cwd=True): return paths -def resolve(filename, paths=()): +def resolve(filename, paths=(), env=True): '''Resolve filename against built-in directories and any user-provided list in "paths". + Elements of paths may be string or pathlib.Path. + Raise ValueError if fail. + If env is a dictionary use it as environment to find more paths. If env is + True use os.environ. Else, environment not considered. + ''' + paths = list(paths) + + if env is True: + paths += wirecell_path() + elif isinstance(env, dict): + paths += wirecell_path(env) + if not filename: raise RuntimeError("no file name provided") if isinstance(filename, str): @@ -51,10 +64,11 @@ def resolve(filename, paths=()): if filename.exists() and filename.is_absolute(): return filename + for maybe in clean_paths(paths): - maybe = Path(maybe) / filename - if maybe.exists(): - return maybe + path = Path(maybe) / filename + if path.exists(): + return path raise RuntimeError(f"file name {filename} not resolved in paths {paths}") diff --git a/wirecell/util/lmn.py b/wirecell/util/lmn.py index da535f0..4eba7f5 100644 --- a/wirecell/util/lmn.py +++ b/wirecell/util/lmn.py @@ -1,6 +1,8 @@ #!/usr/bin/env python ''' The LMN resampling method from the paper . + +fixme: this currently only supports signals in the form of 1D arrays. ''' import numpy @@ -8,6 +10,7 @@ from numpy import pi import dataclasses import matplotlib.pyplot as plt +from wirecell.util.cli import debug @dataclasses.dataclass class Sampling: @@ -222,9 +225,9 @@ def rational(sig, Tr, eps=1e-6): if not nrag: return sig - cur = sig.wave.copy() npad = nrat - nrag Ns += npad + cur = sig.wave.copy() cur.resize(Ns) ss = Sampling(Ts, cur.size) @@ -330,6 +333,7 @@ def interpolate(sig, Tr, eps=1e-6, name=None): ''' rat = rational(sig, Tr, eps) + # debug(f'interpolate: rationalize {sig.sampling} -> {rat.sampling}') Nr = rat.sampling.duration / Tr if abs(Nr - round(Nr)) > eps: @@ -337,13 +341,15 @@ def interpolate(sig, Tr, eps=1e-6, name=None): Nr = round(Nr) res = resample(rat, Nr) + # debug(f'interpolate: resample {rat.sampling} -> {res.sampling}') # rez = resize(res, sig.sampling.duration) rez = res # The response is instantaneous current and thus we use interpolation # normalization. - norm = res.sampling.N / sig.sampling.N + # norm = res.sampling.N / sig.sampling.N + norm = res.sampling.N / rat.sampling.N fin = norm * rez.wave return Signal(Sampling(T=Tr, N=fin.size), wave=fin, name=name) diff --git a/wirecell/util/plottools.py b/wirecell/util/plottools.py index 5a8a910..7b6699f 100644 --- a/wirecell/util/plottools.py +++ b/wirecell/util/plottools.py @@ -24,6 +24,7 @@ def rescaley(ax, x, y, rx, extra=0.1): class NameSequence(object): + def __init__(self, name, first=0, **kwds): ''' Every time called, emit a new name with an index. @@ -80,6 +81,7 @@ def __exit__(self, typ, value, traceback): class NameSingleton(object): + def __init__(self, path, **kwds): ''' Like a NameSequence but force a singleton.