Skip to content

Commit

Permalink
Fix and make more flexible the constructing of U-Net and DNNROI nets.
Browse files Browse the repository at this point in the history
And start adding trainer and tracker.
  • Loading branch information
brettviren committed Nov 4, 2024
1 parent 225eb46 commit 5898771
Show file tree
Hide file tree
Showing 12 changed files with 759 additions and 231 deletions.
123 changes: 112 additions & 11 deletions wirecell/dnn/__main__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#!/usr/bin/env python3

import click

from wirecell.util.cli import context, log, jsonnet_loader
from wirecell.util.paths import unglob
from wirecell.util.paths import unglob, listify


@context("dnn")
def cli(ctx):
Expand All @@ -15,18 +17,110 @@ def cli(ctx):
@click.option("-c", "--config",
type=click.Path(),
help="Set configuration file")
@click.option("-e", "--epochs", default=1, help="Number of epochs over which to train")
@click.option("-b", "--batch", default=1, help="Batch size")
@click.argument("files", nargs=-1)
@click.pass_context
def train_dnnroi(ctx, config, files):
def train(ctx, config, epochs, batch, files):
'''
Train the DNNROI model.
'''
fpaths = unglob(files)
print (fpaths)
# fixme: args to explicitly select use of "flow" tracking.
from wirecell.dnn.tracker import flow

# fixme: make choice of dataset optional
from wirecell.dnn.apps import dnnroi as app



# fixme: this should all be moved under the app
ds = app.Dataset(unglob(files))
imshape = ds[0][0].shape[-2:]
print(f'{imshape=}')

from torch.utils.data import DataLoader
dl = DataLoader(ds, batch_size=batch, shuffle=True)

net = app.Network(imshape, batch_size=batch)
trainer = app.Trainer(net, tracker=flow)

for epoch in range(epochs):
loss = trainer.epoch(dl)
flow.log_metric("epoch_loss", dict(epoch=epoch, loss=loss))

# log.info(config)



@cli.command('extract')
@click.option("-o", "--output", default='samples.npz',
              help="Output in which to save the extracted samples") # fixme: support also hdf
@click.option("-s", "--sample", multiple=True, type=str,
              help="Index or comma separated list of indices for samples to extract")
@click.argument("datapaths", nargs=-1)
@click.pass_context
def extract(ctx, output, sample, datapaths):
    '''
    Extract samples from a dataset.

    The datapaths name files or file globs.

    Each tensor of each selected sample is written as an entry named
    sample_<sample index>_<tensor index>.npy in the output zip archive.
    '''
    # Parse all requested indices before touching the output so that a
    # malformed index fails early instead of leaving a truncated archive.
    samples = list(map(int, listify(*sample, delim=",")))

    # fixme: make choice of dataset optional
    from wirecell.dnn.apps import dnnroi as app
    # NOTE(review): the paths are given to the dataset as-is, while the
    # train command passes them through unglob() first -- confirm whether
    # globs are expanded by the dataset itself.
    ds = app.Dataset(datapaths)

    print(f'dataset has {len(ds)} entries from {len(datapaths)} data paths')

    # fixme: support npz and hdf and move this into an io module.
    import io
    import numpy
    import zipfile              # must diy to append to .npz
    with zipfile.ZipFile(output, 'w') as zf:
        for isam in samples:
            sam = ds[isam]
            for iten, ten in enumerate(sam):
                # Serialize to an in-memory .npy image, then store it as
                # one member of the archive.
                bio = io.BytesIO()
                numpy.save(bio, ten.cpu().detach().numpy())
                zf.writestr(f'sample_{isam}_{iten}.npy', data=bio.getbuffer().tobytes())


@cli.command('plot3p1')
@click.option("-o", "--output", default='samples.png',
              help="Output in which to save the extracted samples") # fixme: support also hdf
@click.option("-s", "--sample", multiple=True, type=str,
              help="Index or comma separated list of indices for samples to extract")
@click.argument("datapaths", nargs=-1)
@click.pass_context
def plot4dnnroi(ctx, output, sample, datapaths):
    '''
    Plot 3 layers from first tensor and 1 image from second from each sample.

    One 2x2 figure is saved per requested sample.
    '''
    samples = list(map(int, listify(*sample, delim=",")))

    # fixme: make choice of dataset optional
    from wirecell.dnn.apps import dnnroi as app
    ds = app.Dataset(datapaths)

    # fixme: move plotting into a dnn.plots module
    import matplotlib.pyplot as plt
    from wirecell.util.plottools import pages
    with pages(output, single=len(samples)==1) as out:

        for idx in samples:
            rec, tru = ds[idx]
            # Move to host memory before converting, as the extract
            # command does; this is a no-op for tensors already on CPU.
            rec = rec.cpu().detach().numpy()
            tru = tru.cpu().detach().numpy()

            fig, axes = plt.subplots(2, 2)
            axes[0][0].imshow(rec[0])
            axes[0][1].imshow(rec[1])
            axes[1][0].imshow(rec[2])
            axes[1][1].imshow(tru[0])

            out.savefig()
            # Release the figure once saved so that plotting many samples
            # does not accumulate open figures.
            plt.close(fig)


@cli.command("vizmod")
Expand All @@ -35,10 +129,14 @@ def train_dnnroi(ctx, config, files):
@click.option("-c","--channels", default=3, help="Number of input image channels")
@click.option("-C","--classes", default=6, help="Number of output classes")
@click.option("-b","--batch", default=1, help="Number of batch images")
@click.option("--skips", default=4, help="Number skip layers")
@click.option("--padding/--no-padding", default=False, is_flag=True, help="Use padding")
@click.option("--bilinear/--no-bilinear", default=False, is_flag=True, help="Use bilinear upsampling")
@click.option("--batchnorm/--no-batchnorm", default=False, is_flag=True, help="Use batch normalization")
@click.option("-o","--output", default=None, help="File name to fill with GraphViz dot")
@click.option("-m","--model", default="UNet",
type=click.Choice(["UNet","UsuyamaUNet", "MilesialUNet","list"]))
def vizmod(shape, channels, classes, batch, output, model):
def vizmod(shape, channels, classes, batch, skips, padding, bilinear, batchnorm, output, model):
'''
Produce a text summary and if -o/--output given also a GraphViz diagram of a named model.
'''
Expand All @@ -58,20 +156,23 @@ def vizmod(shape, channels, classes, batch, output, model):

Mod = getattr(models, model)

mod = Mod(channels, classes, imshape)
print(f'{channels=} {classes=} {imshape=} {skips=} {batchnorm=} {bilinear=} {padding=}')

mod = Mod(channels, classes, imshape, nskips=skips,
batch_norm=batchnorm, bilinear=bilinear, padding=padding)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
mod = mod.to(device)

from torchsummary import summary

full_shape = (channels, imshape[0], imshape[0])
full_shape = (channels, imshape[0], imshape[1])
summary(mod, input_size=full_shape, device=device)

if output:
from torchview import draw_graph
batch_shape = (batch, channels, imshape[0], imshape[0])
gr = draw_graph(mod, input_size=bach_shape, device=device)
batch_shape = (batch, channels, imshape[0], imshape[1])
gr = draw_graph(mod, input_size=batch_shape, device=device)
with open(output, "w") as fp:
fp.write(str(gr.visual_graph))

Expand Down
4 changes: 4 additions & 0 deletions wirecell/dnn/apps/dnnroi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/usr/bin/env python
from .train import Classifier as Trainer
from .data import Dataset
from .model import Network
78 changes: 78 additions & 0 deletions wirecell/dnn/apps/dnnroi/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#!/usr/bin/env python
'''
Dataset and model specific to DNNROI training.
See for example:
https://www.phy.bnl.gov/~hyu/dunefd/dnn-roi-pdvd/Pytorch-UNet/data/
'''
# fixme: add support to load from URL

from wirecell.dnn.data import hdf

from .transforms import Rec as Rect, Tru as Trut, Params as TrParams


class Rec(hdf.Single):
    '''
    A DNNROI "rec" dataset.

    This consists of conventional sigproc results produced by WCT's
    OmnibusSigProc in HDF5 "frame file" form.

    Class attributes:

    - file_re :: default regular expression selecting "rec" files.
    - path_res :: default regular expressions, one for each of the
      frame_loose_lf, frame_mp2_roi and frame_mp3_roi tags.
    '''

    file_re = r'.*g4-rec-r(\d+)\.h5'

    path_res = tuple(
        r'/(\d+)/%s\d'%tag for tag in [
            'frame_loose_lf', 'frame_mp2_roi', 'frame_mp3_roi']
    )

    def __init__(self, paths,
                 file_re=None, path_res=None,
                 trparams: TrParams = Rect.default_params, cache=False):
        '''
        Arguments:

        - paths :: the file paths, passed on to hdf.Single.
        - file_re :: override of the class file_re when given.
        - path_res :: override of the class path_res when given.
        - trparams :: parameters for the "rec" transform.  The default is
          the "rec" transform's own default; the "tru" defaults use a
          different normalization and are not appropriate here.
        - cache :: passed on to hdf.Domain.
        '''
        dom = hdf.Domain(hdf.ReMatcher(file_re or self.file_re,
                                       path_res or self.path_res),
                         transform=Rect(trparams),
                         cache=cache, grad=True,
                         name="dnnroirec")
        super().__init__(dom, paths)


class Tru(hdf.Single):
    '''
    A DNNROI "tru" dataset.

    This consists of the target ROI

    Class attributes:

    - file_re :: default regular expression selecting "tru" files.
    - path_res :: default regular expressions, one for the frame_ductor tag.
    '''

    file_re = r'.*g4-tru-r(\d+)\.h5'

    path_res = tuple(
        r'/(\d+)/%s\d'%tag for tag in ['frame_ductor']
    )

    def __init__(self, paths, threshold = 0.5,
                 file_re=None, path_res=None,
                 trparams: TrParams = Trut.default_params, cache=False):
        '''
        Arguments:

        - paths :: the file paths, passed on to hdf.Single.
        - threshold :: given to the "tru" transform.
        - file_re :: override of the class file_re when given.
        - path_res :: override of the class path_res when given.
        - trparams :: parameters for the "tru" transform.
        - cache :: passed on to hdf.Domain.
        '''
        # Fall back to the class-level patterns when no override is given.
        matcher = hdf.ReMatcher(file_re or self.file_re,
                                path_res or self.path_res)
        truth_transform = Trut(trparams, threshold)
        domain = hdf.Domain(matcher,
                            transform=truth_transform,
                            cache=cache, grad=False,
                            name="dnnroitru")

        super().__init__(domain, paths)


class Dataset(hdf.Multi):
    '''
    The full DNNROI dataset is effectively zip(Rec,Tru).
    '''
    def __init__(self, paths, threshold=0.5, cache=False):
        '''
        Arguments:

        - paths :: the file paths, given to both the Rec and Tru datasets.
        - threshold :: given to the Tru dataset.
        - cache :: given to both the Rec and Tru datasets.
        '''
        # fixme: allow configuring the transforms.
        rec_part = Rec(paths, cache=cache)
        tru_part = Tru(paths, threshold, cache=cache)
        super().__init__(rec_part, tru_part)


15 changes: 15 additions & 0 deletions wirecell/dnn/apps/dnnroi/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch
import torch.nn as nn
from wirecell.dnn.models.unet import UNet

class Network(nn.Module):
    '''
    The DNNROI network: a U-Net followed by a sigmoid.
    '''

    def __init__(self, shape, n_channels=3, n_classes=1, batch_size=1):
        '''
        Build the wrapped U-Net.

        The U-Net is constructed with batch normalization and bilinear
        upsampling both enabled; the remaining arguments are passed to
        it positionally.
        '''
        super().__init__()
        unet_options = dict(batch_norm=True, bilinear=True)
        self.unet = UNet(n_channels, n_classes, shape, batch_size,
                         **unet_options)

    def forward(self, x):
        '''
        Return the sigmoid of the U-Net output for input x.
        '''
        return torch.sigmoid(self.unet(x))

9 changes: 9 additions & 0 deletions wirecell/dnn/apps/dnnroi/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from wirecell.dnn.train import Classifier as Base


class Classifier(Base):
    '''
    The DNNROI classifier

    This only supplies DNNROI-specific default values for the lr,
    momentum and weight_decay options of the base Classifier.
    '''
    def __init__(self, net, lr=0.1, momentum=0.9, weight_decay=0.0005, **optkwds):
        # Gather the named options first so they precede any extra ones
        # in the forwarded keyword arguments.
        named = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
        super().__init__(net, **named, **optkwds)
116 changes: 116 additions & 0 deletions wirecell/dnn/apps/dnnroi/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#!/usr/bin/env python
'''
The dataset transforms relevant to DNNROI
'''

from dataclasses import dataclass
from typing import Type, Tuple


@dataclass
class DimParams:
    '''
    Per-dimension parameters for rec and tru dataset transforms.

    - crop :: a half-open range as slice
    - rebin :: an integer downsampling factor

    A crop that is not already a slice is unpacked into the arguments
    of slice(), so a (start, stop) tuple may be given.

    FYI, common channel counts per wire plane:

    - DUNE-VD: 256+320+288
    - DUNE-HD: 800+800+960
    - Uboone: 2400+2400+3456
    - SBND-v1: 1986+1986+1666
    - SBND-v0200: 1984+1984+1664
    - ICARUS: 1056+5600+5600
    '''
    crop: slice
    rebin: int = 1

    def __post_init__(self):
        # Nothing to do when a real slice was supplied.
        if isinstance(self.crop, slice):
            return
        # Otherwise treat the value as the arguments to slice().
        self.crop = slice(*self.crop)


@dataclass
class Params:
    '''
    Common parameters for rec and tru dataset transforms.

    elech is for electronics channel dimension

    ticks is for sampling period dimension

    values are divided by norm
    '''
    # These fields hold DimParams instances (not the DimParams class).
    elech: DimParams
    ticks: DimParams
    norm: float = 1.0


class Rec:
    '''
    The DNNROI "rec" data transformation.

    Calling an instance crops, then rebins, then normalizes its input.
    '''

    default_params = Params(DimParams((476, 952), 1), DimParams((0,6000), 10), 4000)

    def __init__(self, params: Params = None):
        '''
        Arguments:

        - params :: a Params, the class default_params is used if not given.
        '''
        self.params = params if params else self.default_params

    def crop(self, x):
        '''
        Return x restricted to the configured channel and tick ranges.
        '''
        pars = self.params
        return x[:, pars.elech.crop, pars.ticks.crop]

    def rebin(self, x):
        '''
        Return x with blocks of channels and ticks replaced by their mean.
        '''
        chan_factor = self.params.elech.rebin
        tick_factor = self.params.ticks.rebin
        # Split each of the last two dimensions into (bins, factor) pairs.
        blocked_shape = (
            x.shape[0],                 # 0: image channel
            x.shape[1] // chan_factor,  # 1: rebinned electronics channel
            chan_factor,                # 2
            x.shape[2] // tick_factor,  # 3: rebinned tick
            tick_factor,                # 4
        )
        blocked = x.reshape(blocked_shape)
        # Average over the tick factor, then over the channel factor,
        # leaving (imgch, elech_rebinned, ticks_rebinned).
        tick_averaged = blocked.mean(4)
        return tick_averaged.mean(2)

    def transform(self, x):
        '''
        Apply crop, rebin and division by norm, in that order.
        '''
        cropped = self.crop(x)
        rebinned = self.rebin(cropped)
        return rebinned/self.params.norm

    def __call__(self, x):
        '''
        Input and output are shaped:

        (# of image channels/layers, # electronic channels, # of time samples)

        Last two dimensions of output are rebinned.
        '''
        return self.transform(x)


class Tru(Rec):
    '''
    The DNNROI "tru" data transformation.

    This is same as "rec" but with a thresholding.
    '''

    default_params = Params(DimParams((476, 952), 1), DimParams((0,6000), 10), 200)

    def __init__(self, params: Params = None, threshold: float = 0.5):
        '''
        Arguments (see Rec for more):

        - threshold :: threshold for array values to be set to 0 or 1.
        '''
        chosen = params or self.default_params
        super().__init__(chosen)
        self.threshold = threshold

    def __call__(self, x):
        '''
        Return 1 where the transformed x exceeds the threshold, else 0.
        '''
        above = self.transform(x) > self.threshold
        return above.to(float)


Loading

0 comments on commit 5898771

Please sign in to comment.