Skip to content

Commit

Permalink
Support for resampling, plots and various improvemetns/bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
brettviren committed Jan 31, 2024
1 parent f3836a5 commit 1aa5b25
Show file tree
Hide file tree
Showing 9 changed files with 342 additions and 16 deletions.
82 changes: 80 additions & 2 deletions wirecell/plot/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from wirecell.util import ario, plottools
from wirecell.util.cli import log, context, jsonnet_loader, frame_input, image_output
from wirecell.util import jsio

from wirecell.util.functions import unitify, unitify_parse
from wirecell import units
from pathlib import Path
import numpy
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -96,7 +98,7 @@ def ntier_frames(cmap, output, files):
@cli.command("frame")
@click.option("-n", "--name", default="wave",
type=click.Choice(["wave","spectra"]),
help="The frame plot type name [default=waf]")
help="The frame plot type name [default=wave]")
@click.option("-t", "--tag", default="orig",
help="The frame tag")
@click.option("-u", "--unit", default="ADC",
Expand Down Expand Up @@ -334,6 +336,82 @@ def plot_slice(ax, slc):
out.savefig()


@cli.command("channels")
@click.option("-c", "--channel", multiple=True, default=(), required=True,
help="Specify channels, eg '1,2:4,5' are 1-5 but not 4")
@click.option("-t", "--trange", default=None, type=str,
help="limit time range, eg '0,3*us'")
@click.option("-f", "--frange", default=None, type=str,
help="limit frequency range, eg '0,100*kHz'")
@image_output
@click.argument("frame_files", nargs=-1)
def channels(output, channel, trange, frange, frame_files, **kwds):
'''
Plot channels from multiple frame files.
Frames need not have same sample period (tick).
If --single put all on one plot, else per-channel plots are made
'''

from wirecell.util.frames import load as load_frames

if trange:
trange = unitify_parse(trange)
if frange:
frange = unitify_parse(frange)

channels = list()
for chan in channel:
for one in chan.split(","):
if ":" in one:
f,l = one.split(":")
channels += range(int(f), int(l))
else:
channels.append(int(one))

# fixme: move this mess out of here

frames = {ff: list(load_frames(ff)) for ff in frame_files}

with output as out:

for chan in channels:

fig,axes = plt.subplots(nrows=1, ncols=2)
fig.suptitle(f'channel {chan}')

for fname, frs in frames.items():
stem = Path(fname).stem
for fr in frs:
wave = fr.waves(chan)
axes[0].set_title("waveforms")
axes[0].plot(fr.times/units.us, wave, drawstyle='steps')
if trange:
axes[0].set_xlim(trange[0]/units.us, trange[1]/units.us)
axes[0].set_xlabel("time [us]")

axes[1].set_title("spectra")
axes[1].plot(fr.freqs_MHz, numpy.abs(numpy.fft.fft(wave)),
label=f'{fr.nticks}x{fr.period/units.ns:.0f}ns\n{stem}')
axes[1].set_yscale('log')
if frange:
axes[1].set_xlim(frange[0]/units.MHz, frange[1]/units.MHz)
else:
axes[1].set_xlim(0, fr.freqs_MHz[fr.nticks//2])
axes[1].set_xlabel("frequency [MHz]")
axes[1].legend()
print(fr.nticks, fr.period/units.ns, fr.duration/units.us)


if not out.single:
out.savefig()
plt.clf()
if out.single:
out.savefig()



def main():
cli(obj=dict())

Expand Down
9 changes: 6 additions & 3 deletions wirecell/sigproc/response/persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,13 @@ def load_detector(name):
'''
Load response(s) given a canonical detector name.
'''
if ".json" in name:
raise ValueError(f'detector name looks like a file name: {name}')

fields = detectors.load(name, "fields")
if not fields:
raise IOError(f'failed to load responses for detector "{name}"')
try:
fields = detectors.load(name, "fields")
except KeyError:
raise IOError(f'failed to load fields for detector "{name}"')

if isinstance(fields, list):
return [pod2schema(f) for f in fields]
Expand Down
57 changes: 57 additions & 0 deletions wirecell/util/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,63 @@ def cmd_detectors(path):



@cli.command("resample")
@click.option("-t", "--tick", default='500*ns',
help="Resample the frame to have this sample period with units, eg '500*ns'")
@click.option("-o","--output", type=str, required=True,
help="Output filename")
@click.argument("framefile")
def resample(tick, output, framefile):
'''
Resample a frame file
'''
from . import ario, lmn

Tr = unitify(tick)
print(f'resample to {Tr/units.ns}ns to {output}')


fp = ario.load(framefile)
f_names = [k for k in fp.keys() if k.startswith("frame_")]
c_names = [k for k in fp.keys() if k.startswith("channels_")]
t_names = [k for k in fp.keys() if k.startswith("tickinfo_")]

out = dict()

for fnum, frame_name in enumerate(f_names):
_, suffix = frame_name.split('_',1)
ti = fp[f'tickinfo_{suffix}']
Ts = ti[1]

if Tr == Ts:
print(f'frame {fnum} "{frame_name}" already sampled at {Tr}')
continue

frame = fp[frame_name]
Ns = frame.shape[1]
ss = lmn.Sampling(T=Ts, N=Ns)

Nr = round(Ns * Tr / Ts)
sr = lmn.Sampling(T=Tr, N=Nr)

print(f'{fnum} {frame_name} {ss=} -> {sr=}')

resampled = numpy.zeros((frame.shape[0], Nr), dtype=frame.dtype)
for irow, row in enumerate(frame):
sig = lmn.Signal(ss, wave=row)
resig = lmn.interpolate(sig, Tr)
wave = resig.wave
# if Nr != wave.size:
# print(f'resizing to min({Nr=},{wave.size=})')
Nend = min(Nr, wave.size)
resampled[irow,:Nend] = wave[:Nend]

out[f'frame_{suffix}'] = resampled
out[f'tickinfo_{suffix}'] = numpy.array([ti[0], Tr, ti[2]])
out[f'channels_{suffix}'] = fp[f'channels_{suffix}']

numpy.savez_compressed(output, **out)

@cli.command("resolve")
@click.option("-p","--path", default=(), multiple=True,
help="Add a search path")
Expand Down
7 changes: 5 additions & 2 deletions wirecell/util/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,12 @@ def wrapper(*args, **kwds):

single = kwds.pop("single", None)
if single:
kwds["output"] = plottools.NameSingleton(output, format=fmt)
out = plottools.NameSingleton(output, format=fmt)
out.single = True
else:
kwds["output"] = plottools.pages(output, format=fmt)
out = plottools.pages(output, format=fmt)
out.single = False
kwds["output"] = out

kwds["cmap"] = colormaps[kwds["cmap"]]

Expand Down
8 changes: 5 additions & 3 deletions wirecell/util/fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from pathlib import Path

# fixme: more generic path functions are in jsio which should move here
def wirecell_path():
def wirecell_path(env=os.environ):
'''
Return list of paths from WIRECELL_PATH.
Return list of paths from WIRECELL_PATH environment variable.
- env :: A specific environment dictionary else os.environ is used.
'''
return tuple(os.environ.get("WIRECELL_PATH","").split(":"))
return tuple(env.get("WIRECELL_PATH","").split(":"))

def source_type(name):
'''
Expand Down
161 changes: 161 additions & 0 deletions wirecell/util/frames.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
#!/usr/bin/env python
'''
Some helpers for frame like objects
'''

from . import ario
import numpy
import dataclasses
from wirecell import units

@dataclasses.dataclass
class Frame:

samples : numpy.ndarray | None = None
'''
Frame samples as 2D array shape (nchans, nticks).
'''

channels : numpy.ndarray | None = None
'''
Array of channel identity numbers.
'''

period: float | None = None
'''
The time-domain sampling period (aka "tick").
'''

tref: float = 0.0
'''
The reference time.
'''

tbin: int = 0
'''
The time bin represented by the first column of samples array.
'''

name: str = ""
'''
Some human identifier.
'''

@property
def nchans(self):
'''
Number of channels
'''
return self.samples.shape[0]

@property
def chan_bounds(self):
'''
Pair of min/max channel number
'''
return (numpy.min(self.channels), numpy.max(self.channels))

@property
def nticks(self):
'''
Number of samples in time
'''
return self.samples.shape[1]

@property
def tstart(self):
'''
The time of the first sample
'''
return self.tref + self.period*self.tbin

@property
def duration(self):
'''
The time spanned by the samples
'''
return self.nticks*self.period

@property
def absolute_times(self):
'''
An array of absolute times of samples
'''
t0 = self.tstart
tf = self.duration
return numpy.linspace(t0, tf, self.nticks, endpoint=False)

@property
def times(self):
'''
An array of times of samples relative to first
'''
t0 = 0
tf = self.duration
return numpy.linspace(t0, tf, self.nticks, endpoint=False)

@property
def Fmax_hz(self):
'''
Max sampling frequency in hz
'''
T_s = self.period/units.s
return 1/T_s

@property
def freqs_MHz(self):
'''
An array of frequencies in MHz of Fourier-domain samples
'''
return numpy.linspace(0, self.Fmax_hz/1e6, self.nticks, endpoint=False)

@property
def chids2row(self):
'''
Return a channel ID to row index
'''
return {c:i for i,c in enumerate(self.channels)}

def waves(self, chans):
'''
Return waveform rows for channel ID or sequence of IDs in chans.
'''
scalar = False
if isinstance(chans, int):
chans = [chans]
scalar = True

lu = self.chids2row

nchans = len(chans)
shape = (nchans, self.nticks)
ret = numpy.zeros(shape, dtype=self.samples.dtype)
for ind,ch in enumerate(chans):
row = lu[ch]
ret[ind,:] = self.samples[row]
if scalar:
return ret[0]
return ret

def __str__(self):
return f'({self.nchans}ch,{self.nticks}x{self.period/units.ns:.0f}ns) @ {self.tstart/units.us:.0f}us'

def load(fp):
'''
Yield frame objects in fp.
fp is file name as string or pathlib.Path or ario/numpy.load() like.
'''
if isinstance(fp, str):
fp = ario.load(fp)

frame_names = [key for key in fp if key.startswith("frame_")]
for frame_name in frame_names:
_,tag,num = frame_name.split("_")

ti = fp[f'tickinfo_{tag}_{num}']
ch = fp[f'channels_{tag}_{num}']

yield Frame(samples=fp[frame_name], channels=ch,
period=ti[1], tref=ti[0], tbin=int(ti[2]),
name=frame_name)
Loading

0 comments on commit 1aa5b25

Please sign in to comment.