From 5898771e0c373593c7a998f838e8f434f74ab93d Mon Sep 17 00:00:00 2001 From: Brett Viren Date: Mon, 4 Nov 2024 16:17:09 -0500 Subject: [PATCH] Fix and make more flexible the constructing of U-Net and DNNROI nets. And start adding trainer and tracker. --- wirecell/dnn/__main__.py | 123 ++++++++- wirecell/dnn/apps/dnnroi/__init__.py | 4 + wirecell/dnn/apps/dnnroi/data.py | 78 ++++++ wirecell/dnn/apps/dnnroi/model.py | 15 ++ wirecell/dnn/apps/dnnroi/train.py | 9 + wirecell/dnn/apps/dnnroi/transforms.py | 116 +++++++++ wirecell/dnn/docs/dnnroi.org | 72 ++++++ wirecell/dnn/docs/unet.dot | 20 ++ wirecell/dnn/models/unet.py | 339 +++++++++---------------- wirecell/dnn/tracker.py | 179 +++++++++++++ wirecell/dnn/train.py | 10 +- wirecell/util/plottools.py | 25 +- 12 files changed, 759 insertions(+), 231 deletions(-) create mode 100644 wirecell/dnn/apps/dnnroi/__init__.py create mode 100644 wirecell/dnn/apps/dnnroi/data.py create mode 100644 wirecell/dnn/apps/dnnroi/model.py create mode 100644 wirecell/dnn/apps/dnnroi/train.py create mode 100644 wirecell/dnn/apps/dnnroi/transforms.py create mode 100644 wirecell/dnn/docs/unet.dot create mode 100644 wirecell/dnn/tracker.py diff --git a/wirecell/dnn/__main__.py b/wirecell/dnn/__main__.py index 3a9f7d1..39d9cda 100644 --- a/wirecell/dnn/__main__.py +++ b/wirecell/dnn/__main__.py @@ -1,8 +1,10 @@ #!/usr/bin/env python3 import click + from wirecell.util.cli import context, log, jsonnet_loader -from wirecell.util.paths import unglob +from wirecell.util.paths import unglob, listify + @context("dnn") def cli(ctx): @@ -15,18 +17,110 @@ def cli(ctx): @click.option("-c", "--config", type=click.Path(), help="Set configuration file") +@click.option("-e", "--epochs", default=1, help="Number of epochs over which to train") +@click.option("-b", "--batch", default=1, help="Batch size") @click.argument("files", nargs=-1) @click.pass_context -def train_dnnroi(ctx, config, files): +def train(ctx, config, epochs, batch, files): ''' Train the DNNROI model. ''' - fpaths = unglob(files) - print (fpaths) + # fixme: args to explicitly select use of "flow" tracking. + from wirecell.dnn.tracker import flow + + # fixme: make choice of dataset optional + from wirecell.dnn.apps import dnnroi as app + + + + # fixme: this should all be moved under the app + ds = app.Dataset(unglob(files)) + imshape = ds[0][0].shape[-2:] + print(f'{imshape=}') + + from torch.utils.data import DataLoader + dl = DataLoader(ds, batch_size=batch, shuffle=True) + + net = app.Network(imshape, batch_size=batch) + trainer = app.Trainer(net, tracker=flow) + + for epoch in range(epochs): + loss = trainer.epoch(dl) + flow.log_metric("epoch_loss", dict(epoch=epoch, loss=loss)) + + # log.info(config) + + + +@cli.command('extract') +@click.option("-o", "--output", default='samples.npz', + help="Output in which to save the extracted samples") # fixme: support also hdf +@click.option("-s", "--sample", multiple=True, type=str, + help="Index or comma separated list of indices for samples to extract") +@click.argument("datapaths", nargs=-1) +@click.pass_context +def extract(ctx, output, sample, datapaths): + ''' + Extract samples from a dataset. + + The datapaths name files or file globs. + ''' + samples = map(int,listify(*sample, delim=",")) + + # fixme: make choice of dataset optional + from wirecell.dnn.apps import dnnroi as app + ds = app.Dataset(datapaths) + + print(f'dataset has {len(ds)} entries from {len(datapaths)} data paths') + + # fixme: support npz and hdf and move this into an io module. + import io + import numpy + import zipfile # must diy to append to .npz + from pathlib import Path + with zipfile.ZipFile(output, 'w') as zf: + for isam in samples: + sam = ds[isam] + for iten, ten in enumerate(sam): + bio = io.BytesIO() + numpy.save(bio, ten.cpu().detach().numpy()) + zf.writestr(f'sample_{isam}_{iten}.npy', data=bio.getbuffer().tobytes()) + + +@cli.command('plot3p1') +@click.option("-o", "--output", default='samples.png', + help="Output in which to save the extracted samples") # fixme: support also hdf +@click.option("-s", "--sample", multiple=True, type=str, + help="Index or comma separated list of indices for samples to extract") +@click.argument("datapaths", nargs=-1) +@click.pass_context +def plot4dnnroi(ctx, output, sample, datapaths): + ''' + Plot 3 layers from first tensor and 1 image from second from each sample. + ''' + + samples = list(map(int,listify(*sample, delim=","))) - log.info(config) - from wirecell.dnn.apps import dnnroi + # fixme: make choice of dataset optional + from wirecell.dnn.apps import dnnroi as app + ds = app.Dataset(datapaths) + # fixme: move plotting into a dnn.plots module + import matplotlib.pyplot as plt + from wirecell.util.plottools import pages + with pages(output, single=len(samples)==1) as out: + + for idx in samples: + rec,tru = ds[idx] + rec = rec.detach().numpy() + tru = tru.detach().numpy() + fig,axes = plt.subplots(2,2) + axes[0][0].imshow(rec[0]) + axes[0][1].imshow(rec[1]) + axes[1][0].imshow(rec[2]) + axes[1][1].imshow(tru[0]) + + out.savefig() @cli.command("vizmod") @@ -35,10 +129,14 @@ def train_dnnroi(ctx, config, files): @click.option("-c","--channels", default=3, help="Number of input image channels") @click.option("-C","--classes", default=6, help="Number of output classes") @click.option("-b","--batch", default=1, help="Number of batch images") +@click.option("--skips", default=4, help="Number skip layers") +@click.option("--padding/--no-padding", default=False, is_flag=True, help="Use padding") +@click.option("--bilinear/--no-bilinear", default=False, is_flag=True, help="Use bilinear upsampling") +@click.option("--batchnorm/--no-batchnorm", default=False, is_flag=True, help="Use batch normalization") @click.option("-o","--output", default=None, help="File name to fill with GraphViz dot") @click.option("-m","--model", default="UNet", type=click.Choice(["UNet","UsuyamaUNet", "MilesialUNet","list"])) -def vizmod(shape, channels, classes, batch, output, model): +def vizmod(shape, channels, classes, batch, skips, padding, bilinear, batchnorm, output, model): ''' Produce a text summary and if -o/--output given also a GraphViz diagram of a named model. ''' @@ -58,20 +156,23 @@ def vizmod(shape, channels, classes, batch, output, model): Mod = getattr(models, model) - mod = Mod(channels, classes, imshape) + print(f'{channels=} {classes=} {imshape=} {skips=} {batchnorm=} {bilinear=} {padding=}') + + mod = Mod(channels, classes, imshape, nskips=skips, + batch_norm=batchnorm, bilinear=bilinear, padding=padding) # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = 'cpu' mod = mod.to(device) from torchsummary import summary - full_shape = (channels, imshape[0], imshape[0]) + full_shape = (channels, imshape[0], imshape[1]) summary(mod, input_size=full_shape, device=device) if output: from torchview import draw_graph - batch_shape = (batch, channels, imshape[0], imshape[0]) - gr = draw_graph(mod, input_size=bach_shape, device=device) + batch_shape = (batch, channels, imshape[0], imshape[1]) + gr = draw_graph(mod, input_size=batch_shape, device=device) with open(output, "w") as fp: fp.write(str(gr.visual_graph)) diff --git a/wirecell/dnn/apps/dnnroi/__init__.py b/wirecell/dnn/apps/dnnroi/__init__.py new file mode 100644 index 0000000..f5f1ea2 --- /dev/null +++ b/wirecell/dnn/apps/dnnroi/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python +from .train import Classifier as Trainer +from .data import Dataset +from .model import Network diff --git a/wirecell/dnn/apps/dnnroi/data.py b/wirecell/dnn/apps/dnnroi/data.py new file mode 100644 index 0000000..3c948ec --- /dev/null +++ b/wirecell/dnn/apps/dnnroi/data.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python +''' +Dataset and model specific to DNNROI training. + +See for example: + +https://www.phy.bnl.gov/~hyu/dunefd/dnn-roi-pdvd/Pytorch-UNet/data/ +''' +# fixme: add support to load from URL + +from wirecell.dnn.data import hdf + +from .transforms import Rec as Rect, Tru as Trut, Params as TrParams + + +class Rec(hdf.Single): + ''' + A DNNROI "rec" dataset. + + This consists of conventional sigproc results produced by WCT's + OmnibusSigProc in HDF5 "frame file" form. + ''' + + file_re = r'.*g4-rec-r(\d+)\.h5' + + path_res = tuple( + r'/(\d+)/%s\d'%tag for tag in [ + 'frame_loose_lf', 'frame_mp2_roi', 'frame_mp3_roi'] + ) + + def __init__(self, paths, + file_re=None, path_res=None, + trparams: TrParams = Trut.default_params, cache=False): + + dom = hdf.Domain(hdf.ReMatcher(file_re or self.file_re, + path_res or self.path_res), + transform=Rect(trparams), + cache=cache, grad=True, + name="dnnroirec") + super().__init__(dom, paths) + + +class Tru(hdf.Single): + ''' + A DNNROI "tru" dataset. + + This consists of the target ROI + ''' + + file_re = r'.*g4-tru-r(\d+)\.h5' + + path_res = tuple( + r'/(\d+)/%s\d'%tag for tag in ['frame_ductor'] + ) + + def __init__(self, paths, threshold = 0.5, + file_re=None, path_res=None, + trparams: TrParams = Trut.default_params, cache=False): + + dom = hdf.Domain(hdf.ReMatcher(file_re or self.file_re, + path_res or self.path_res), + transform=Trut(trparams, threshold), + cache=cache, grad=False, + name="dnnroitru") + + super().__init__(dom, paths) + + +class Dataset(hdf.Multi): + ''' + The full DNNROI dataset is effectively zip(Rec,Tru). + ''' + def __init__(self, paths, threshold=0.5, cache=False): + # fixme: allow configuring the transforms. + super().__init__(Rec(paths, cache=cache), + Tru(paths, threshold, cache=cache)) + + diff --git a/wirecell/dnn/apps/dnnroi/model.py b/wirecell/dnn/apps/dnnroi/model.py new file mode 100644 index 0000000..a68bfe7 --- /dev/null +++ b/wirecell/dnn/apps/dnnroi/model.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +from wirecell.dnn.models.unet import UNet + +class Network(nn.Module): + + def __init__(self, shape, n_channels=3, n_classes=1, batch_size=1): + super().__init__() + self.unet = UNet(n_channels, n_classes, shape, batch_size, + batch_norm=True, bilinear=True) + + def forward(self, x): + x = self.unet(x) + return torch.sigmoid(x) + diff --git a/wirecell/dnn/apps/dnnroi/train.py b/wirecell/dnn/apps/dnnroi/train.py new file mode 100644 index 0000000..1214b98 --- /dev/null +++ b/wirecell/dnn/apps/dnnroi/train.py @@ -0,0 +1,9 @@ +from wirecell.dnn.train import Classifier as Base + + +class Classifier(Base): + ''' + The DNNROI classifier + ''' + def __init__(self, net, lr=0.1, momentum=0.9, weight_decay=0.0005, **optkwds): + super().__init__(net, lr=lr, momentum=momentum, weight_decay=weight_decay, **optkwds) diff --git a/wirecell/dnn/apps/dnnroi/transforms.py b/wirecell/dnn/apps/dnnroi/transforms.py new file mode 100644 index 0000000..f575ef5 --- /dev/null +++ b/wirecell/dnn/apps/dnnroi/transforms.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python +''' +The dataset transforms relevant to DNNROI +''' + +from dataclasses import dataclass +from typing import Type, Tuple + + +@dataclass +class DimParams: + ''' + Per-dimension parameters for rec and tru dataset transforms. + + - crop :: a half-open range as slice + - rebin :: an integer downsampling factor + + FYI, common channel counts per wire plane: + - DUNE-VD: 256+320+288 + - DUNE-HD: 800+800+960 + - Uboone: 2400+2400+3456 + - SBND-v1: 1986+1986+1666 + - SBND-v0200: 1984+1984+1664 + - ICARUS: 1056+5600+5600 + ''' + crop: slice + rebin: int = 1 + + def __post_init__(self): + if not isinstance(self.crop, slice): + self.crop = slice(*self.crop) + + +@dataclass +class Params: + ''' + Common parameters for rec and tru dataset transforms. + + elech is for electronics channel dimension + ticks is for sampling period dimension + values are divided by norm + ''' + elech: Type[DimParams] + ticks: Type[DimParams] + norm: float = 1.0 + + +class Rec: + ''' + The DNNROI "rec" data transformation. + ''' + + default_params = Params(DimParams((476, 952), 1), DimParams((0,6000), 10), 4000) + + def __init__(self, params: Params = None): + ''' + Arguments: + + - params :: a Params + + ''' + self.params = params or self.default_params + + def crop(self, x): + return x[:, self.params.elech.crop, self.params.ticks.crop] + + def rebin(self, x): + ne, nt = self.params.elech.rebin, self.params.ticks.rebin, + sh = (x.shape[0], # 0 + x.shape[1] // ne, # 1 + ne, # 2 + x.shape[2] // nt, # 3 + nt) # 4 + return x.reshape(sh).mean(4).mean(2) # (imgch, elech_rebinned, ticks_rebinned) + + def transform(self, x): + x = self.crop(x) + x = self.rebin(x) + x = x/self.params.norm + return x + + + def __call__(self, x): + ''' + Input and output are shaped: + + (# of image channels/layers, # electronic channels, # of time samples) + + Last two dimensions of output are rebinned. + ''' + return self.transform(x) + + +class Tru(Rec): + ''' + The DNNROI "tru" data transformation. + + This is same as "rec" but with a thresholding. + ''' + + default_params = Params(DimParams((476, 952), 1), DimParams((0,6000), 10), 200) + + def __init__(self, params: Params = None, threshold: float = 0.5): + ''' + Arguments (see Rec for more): + + - threshold :: threshold for array values to be set to 0 or 1. + ''' + super().__init__(params or self.default_params) + self.threshold = threshold + + def __call__(self, x): + x = self.transform(x) + return (x > self.threshold).to(float) + + diff --git a/wirecell/dnn/docs/dnnroi.org b/wirecell/dnn/docs/dnnroi.org index 4ea778e..3464945 100644 --- a/wirecell/dnn/docs/dnnroi.org +++ b/wirecell/dnn/docs/dnnroi.org @@ -59,3 +59,75 @@ Every "tru" sample has additional processing: - crop as for "rec" - threshold to set value 0 or 1. +* Network + +The DNNROI network is examined and compared to U-Net. + +** Dimensions + +The dimensions of output tensors from major units of the DNNROI network: + + + | dir | level | unit | ch | hpx | wpx | + |--------+-------+----------+------+-----+-----| + | in | | input | 3 | 476 | 600 | + | down | 0 | dconv | 64 | 476 | 600 | + | down | 0 | pool | 64 | 238 | 300 | + | down | 1 | dconv | 128 | 238 | 300 | + | down | 1 | pool | 128 | 119 | 150 | + | down | 2 | dconv | 256 | 119 | 150 | + | down | 2 | pool | 256 | 59 | 75 | + | down | 3 | dconv | 512 | 59 | 75 | + | down | 3 | pool | 512 | 29 | 37 | + | bottom | 4 | dconv | 512 | 29 | 37 | + | up | 3 | upsample | 512 | 58 | 74 | + | up | 3 | pad | 512 | 59 | 75 | + | up | 3 | cat | 1024 | 59 | 75 | + | up | 3 | dconv | 256 | 59 | 75 | + | up | 2 | upsample | 256 | 118 | 150 | + | up | 2 | pad | 256 | 119 | 150 | + | up | 2 | cat | 512 | 119 | 150 | + | up | 2 | dconv | 128 | 119 | 150 | + | up | 1 | upsample | 128 | 238 | 300 | + | up | 1 | pad | 128 | 238 | 300 | + | up | 1 | cat | 256 | 238 | 300 | + | up | 1 | dconv | 64 | 238 | 300 | + | up | 0 | upsample | 64 | 476 | 600 | + | up | 0 | pad | 64 | 476 | 600 | + | up | 0 | cat | 128 | 476 | 600 | + | up | 0 | dconv | 64 | 476 | 600 | + | out | | conv | 1 | 476 | 600 | + | out | | sigmoid | 1 | 476 | 600 | + + +** Deviations from U-Net + +The DNNROI network architecture takes inspiration from U-Net but deviates in many details from what is described in the U-Net paper. In general, U-Net is more regular in image channel dimension while DNNROI is more regular in image pixel dimensions. To achieve this and other goals, DNNROI inserts and swaps certain operations. The following is a summary of the differences. + +The main unit that makes up the "U" shape is the "double convolution" (dconv). +It is this general unit that is most modified in DNNROI from U-Net. These +modifications vary depending on whether the dconv is in downward, bottom or +upward legs of the "U". + +- DNNROI inserts a batch normalization between each pair of 2D convolution and ReLU. + +- The U-net dconv reduces pixel dimension by 4 while DNNROI zero-pads after each + 2D convlution and the pixel dimension sizes remains unchanged. + +- After the initial inflation to size 64, both U-Net and DNNROI dconv on the + downward leg doubles the channel dimension. U-Net dconv on the "bottom" + of the "U" also doubles this dimension while DNNROI does not. + +- On the upward leg, U-Net dconv uniformly halves the channel dimension size. + DNNROI quarters this dimension for the first three upward steps and halves it + for the final step. + +Changes in other units: + +- DNNROI uses bilinear interpolated upsampling while U-Net uses ~ConvTranspose2d~. + +- As a consequence, DNNROI zero-pads the pixel dimensions after upsampling when an odd target size is required. The image size of 572 pixels used in the U-Net paper avoids encountering odd upsampling target sizes (relying on the 4 pixel loss in U-Net dconv). + +- DNNROI applies a final sigmoid to U-Net's output segmentation map. + +- The U-Net skip transfers a core crop of the array whereas DNNROI preserves the entire array across the skip connection. This is enabled by DNNROI applying padding in dconv units. As a consequence, U-Net skip operation depends on the "natural" sizes of the upward U leg while DNNROI's upward U leg sizes depends on the "natural" sizes of the downward U leg. diff --git a/wirecell/dnn/docs/unet.dot b/wirecell/dnn/docs/unet.dot new file mode 100644 index 0000000..53558a2 --- /dev/null +++ b/wirecell/dnn/docs/unet.dot @@ -0,0 +1,20 @@ +digraph unet { + + {rank=same in osm out} + {rank=same ddc0 udc0 } + {rank=same ddc1 udc1 } + {rank=same ddc2 udc2 } + {rank=same ddc3 udc3 } + {rank=same d0 u0} + {rank=same d1 u1} + {rank=same d2 u2} + {rank=same d3 u3} + + in->ddc0->d0->ddc1->d1->ddc2->d2->ddc3->d3->bdc + bdc->u3->udc3->u2->udc2->u1->udc1->u0->udc0->osm + ddc0->u0 + ddc1->u1 + ddc2->u2 + ddc3->u3 + osm->out +} diff --git a/wirecell/dnn/models/unet.py b/wirecell/dnn/models/unet.py index b9aa2ed..12de12e 100644 --- a/wirecell/dnn/models/unet.py +++ b/wirecell/dnn/models/unet.py @@ -5,9 +5,6 @@ https://arxiv.org/abs/1505.04597 https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png -Unlike several other implementations using the name "U-Net" that one runs -across, this one tries to exactly replicate what is in the paper, by default. - The following labels are used to identify units of the network and refers to the u-net-architecture.png figure. @@ -24,134 +21,27 @@ - usamp :: "up sampling" (green arrow), the "up-conv 2x2" that input from dconv result and output to umerge. -- skip :: "skip connection" (gray arrow), the center crop of dconv output and - provides input to umerge. A "skip level" counts the skips from 0 starting at - the top of the U. The "bottom" can be considered a skip level for purposes of - calculating the output size of its dconv. +- skip :: "skip connection" (gray arrow), this simply shunts the output from a + dconv on the downward leg to one input of a umerge. - umerge :: "up merge" (gray+green arrows), concatenation of the skip result - with the usamp result and provides input to an dconv on the upward leg. + with the up samping result and provides input to an dconv on the upward leg. The default configuration produces U-Net. The following optional extensions, off by default, are supported: -- insert two BatchNorm2d in double convolution unit (dconv). -- use other than 4 skip connection levels. -- use non-square images data. +- batch_norm=True :: insert two BatchNorm2d in double convolution unit (dconv). +- bilinear=True :: use bilinear interpolation instead of ConvTranspose2d in up-conv +- padding=True :: zero-pad in dconv so image input size is retained and in umerge is needed to match arrays from skip and below connections. +- nskips=N :: use a different number of skip connection levels besides 4. +- use non-square images. ''' import torch import torch.nn as nn -from torch.nn.functional import grid_sample - - -def down_out_size(orig_size, skip_level): - ''' - Return the output size from a down unit at a given skip level. - - Skip level counts from 0 transfer "over" the U via skip connection or bottom. - - ''' - size = orig_size - for dskip in range(skip_level + 1): - if dskip: - size = size // 2 - size = size - 4 - return size - -def up_in_size(orig_size, skip_level, nlevels = 4): - ''' - Return the input size to an up unit (output of a skip) at a given skip level. - - The nlevels counts the number of skip connections across the U. - ''' - size = down_out_size(orig_size, nlevels) - for uskip in range( nlevels - skip_level): - if uskip: - size = size - 4 - size = size * 2 - return size - - -def dimension(in_channels = 1, in_size = 572, nskips = 4): - ''' - Calculate 1D image channel and pixel dimensions for elements of the U-Net. - - - size :: the size of both input image dimensions (572 for U-Net paper). - - - nskips :: the number of skip connections (4 for U-Net paper) - - This returns four lists of size 2*nskips+1. Each element of a list - corresponds to one major "dconv" unit as we go along the U: nskips "down", - one "bottom" and nskips "up". The lists are: - - - number of input channels - - number of output channels - - input size - - output size - - The [nskips] element refers to the bottom dconv. - - See skip_dimensions() to form similar lists from the output of this function - for the skip connections. - - Note, the output segmentation map is excluded. The final element in the - lists refers to the top up dconv. - ''' - chans_down_in = [in_channels] + [2**(6+n) for n in range(nskips)] # includes bottom - chans_down_out = [2**(6+n) for n in range(nskips+1)] - chans_up_in = list(chans_down_out[1:]) - chans_up_in.reverse() - chans_in = chans_down_in + chans_up_in - chans_up_out = chans_down_in[1:] - chans_up_out.reverse() - chans_out = chans_down_out + chans_up_out - - size_in = [in_size] - size_out = [] - for skip in range(nskips): - siz = size_in[-1] - 4 # dconv reduction - size_out.append(siz) - size_in.append(siz // 2) # max pool reduction - size_out.append(size_in[-1] - 4) # bottom out - for rskip in range(nskips): - size_in.append(size_out[-1] * 2) # up conv - size_out.append(size_in[-1] - 4) # dconv reduction - - return (chans_in, chans_out, size_in, size_out) - - -def dimensions(in_channels = 1, in_shape = (572,572), nskips = 4): - ''' - N-D version of dimension() where sizes are shapes. - ''' - dims = [dimension(in_channels, in_size, nskips) for in_size in in_shape] - in_chans = dims[0][0] - out_chans = dims[0][1] - in_shapes = tuple(zip(*[d[2] for d in dims])) - out_shapes = tuple(zip(*[d[2] for d in dims])) - return in_chans, out_chans, in_shapes, out_shapes - - -def skip_dimensions(dims): - ''' - Reformat the output of dimensions() to the same form but for the skip - connections in order of skip level. - ''' - (chans_in, chans_out, shape_in, shape_out) = dims - - nskips = (len(chans_in)-1)//2 - - schans_in = chans_out[:nskips] - schans_out = schans_in # skips preserve channel dim - - sshape_in = shape_out[:nskips] - sshape_out = list(shape_in[-nskips:]) - sshape_out.reverse() - sshape_out = tuple(sshape_out) - return (schans_in, schans_out, sshape_in, sshape_out) +from torch.nn.functional import pad as nnpad def dconv(in_channels, out_channels, kernel_size = 3, padding = 0, @@ -167,8 +57,8 @@ def dconv(in_channels, out_channels, kernel_size = 3, padding = 0, ] if batch_norm: - parts.insert(3, nn.BatchNorm2d(out_ch)) - parts.insert(1, nn.BatchNorm2d(out_ch)) + parts.insert(3, nn.BatchNorm2d(out_channels)) + parts.insert(1, nn.BatchNorm2d(out_channels)) return nn.Sequential(*parts) @@ -180,138 +70,155 @@ def dsamp(): return nn.MaxPool2d(2) -def usamp(in_ch): - ''' - The "up sampling". - ''' - return nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) - # return nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - - -# def build_grid(source_size, target_size, batch_size = 1): -# ''' -# Map output pixels to input pixels for cropping by grid_sample(). - -# This assumes square images of given size. -# ''' -# # simplified version of what is given in -# # https://discuss.pytorch.org/t/cropping-a-minibatch-of-images-each-image-a-bit-differently/12247 -# k = float(target_size)/float(source_size) -# direct = torch.linspace(-k,k,target_size).unsqueeze(0).repeat(target_size,1).unsqueeze(-1) -# grid = torch.cat([direct,direct.transpose(1,0)],dim=2).unsqueeze(0) -# return grid.repeat(batch_size, 1, 1, 1) - -class skip(nn.Module): - ''' - The "skip connection" providing a core cropping. +class umerge(nn.Module): ''' - def __init__(self, source_shape, target_shape, batch_size=1): - super().__init__() - self.crop = [] - for ssize, tsize in zip(source_shape, target_shape): - margin = (ssize - tsize)//2 - self.crop.append (slice(margin, margin+tsize)) + The "upsample merge" of the outputs from a skip and a dconv. - # A fancier way to do it which, but why? - # self.register_buffer('g', build_grid(source_size, target_size, batch_size)) - # grid should have shape: (nbatch, nrows, ncols, 2) + The "up" array is upsampled and then appended to the "over" array. - def forward(self, x): - # x must be (nbatch, nchannel, nrows, ncols) - # print(f'grid: {self.g.shape} {self.g.dtype} {self.g.device}') - # print(f'data: {x.shape} {x.dtype} {x.device}') - # return grid_sample(x, self.g, align_corners=True, mode='nearest') - return x[:,:,self.crop[0],self.crop[1]] + Both options have large repercussion on upstream nodes: + If bilinear, the number of channels in the upsampled array is unchanged else + it is halved. -class umerge(nn.Module): + If padded, the upsampled array pixel dimensions will be padded to match + those of the "over" array. ''' - The "upsample merge" of the outputs from a skip and a dconv. - ''' - def __init__(self, nchannels): + def __init__(self, nchannels, bilinear=False, padding=False): ''' Give number of channels in the input to the upsampling port. ''' super().__init__() - self._nchannels = nchannels - self.upsamp = nn.ConvTranspose2d(nchannels, nchannels//2, 2, stride=2) + self.padding = padding + self.pads = None + self.slices = None + if bilinear: + self.upsamp = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + else: + self.upsamp = nn.ConvTranspose2d(nchannels, nchannels//2, 2, stride=2) def forward(self, over, up): up = self.upsamp(up) + + if self.padding: + # when not cropping we must pad special to match when target is odd size + if self.pads is None: + pad = list() + for dim in [-1, -2]: + diff = over.shape[dim] - up.shape[dim] + half = diff // 2 + pad += [half, diff-half] + self.pads = tuple(pad) + up = nnpad(up, self.pads) + else: + if self.slices is None: + slices = [slice(None),] * len(up.shape) # select all by default + for dim in [-2,-1]: + hi = over.shape[dim] + lo = up.shape[dim] + if lo == hi: + continue + beg = (hi - lo) // 2 + end = beg + lo + slices[dim] = slice(beg,end) + print(f'{slices=}\n{over.shape=} {up.shape=}') + self.slices = tuple(slices) + over = over[self.slices] + print(f'{over.shape=}') + cat = torch.cat((over, up), dim=1) + print(f'{cat.shape=}') return cat +def make_dconv(ich, factor, padding=False, batch_norm=False): + n_padding = 1 if padding else 0 + och = ich + if ich < 64: + och = 64 # special first case + elif factor != 1: + och = int(ich*factor) + print(f'dconv {ich=} {och=} {factor=} {padding=} {batch_norm=}') + node = dconv(ich, och, padding=n_padding, batch_norm=batch_norm) + return node, och + +def make_dsamp(ich): + return dsamp(), ich + +def make_umerge(ich, bilinear=False, padding=False): + ''' + ich is number of channels from the skip + ''' + # Assume umerge halves the number of channels from the input below. + och = ich * 2 + return umerge(2*ich, bilinear=bilinear, padding=padding), och + + + class UNet(nn.Module): ''' U-Net model exactly as from the paper by default. ''' - def __init__(self, n_channels=3, n_classes=6, in_shape=(572,572), - batch_size=1, nskips=4, - batch_norm=False): + nskips=4, + batch_norm=False, bilinear=False, padding=False): super().__init__() - self.nskips = nskips - dims = dimensions(n_channels, in_shape, nskips) + nch = n_channels - # The major elements of the U - chans_in, chans_out, _, _ = dims + self.downleg = list() # nodes in downward U leg + skip_nchannels = list() # for making skips + for iskip in range(nskips): # go down the U making dconv and dsamp - # Note; we use setattr to make sure PyTorch summary finds the submodules. + dc_node, nch = make_dconv(nch, factor=2, padding=padding, batch_norm=batch_norm) + setattr(self, f'down_dconv_{iskip}', dc_node) + skip_nchannels.append(nch) - # The downward leg of the U. - for ind in range(nskips): - setattr(self, f'downleg_{ind}', dconv(chans_in[ind], chans_out[ind])) + ds_node, nch = make_dsamp(nch) + setattr(self, f'down_dsamp_{iskip}', ds_node) - # The bottom of the U - self.bottom = dconv(chans_in[nskips], chans_out[nskips]) + self.downleg.append((dc_node, ds_node)) - # The upward leg of the U. - for count, ind in enumerate(range(nskips+1, 2*nskips+1)): - setattr(self, f'upleg_{count}', dconv(chans_in[ind], chans_out[ind])) + factor = 1 if padding else 2 + self.bottom, nch = make_dconv(nch, factor=factor, padding=padding, batch_norm=batch_norm) - # The skip connections get applied top-down - schans_in, schans_out, ssize_in, ssize_out = skip_dimensions(dims) - for ind, ss in enumerate(zip(ssize_in, ssize_out)): - setattr(self, f'skip_{ind}', skip(*ss, batch_size=batch_size)) + # self.skips = list() + self.upleg = list() + for iskip in range(nskips-1, -1, -1): - # And the merges are applied bottom-up. - # We bake in the rule that upsample input has 2x the number of channels as the skip output. - umerges = [umerge(2*nc) for nc in schans_out] - umerges.reverse() - for ind, um in enumerate(umerges): - setattr(self, f'umerge_{ind}', um); + nch = skip_nchannels[iskip] + m_node, nch = make_umerge(nch, bilinear=bilinear, padding=padding) + setattr(self, f'up_umerge_{iskip}', m_node) - # Downsampler is data-independent and reused. - self.dsamp = dsamp() + factor = 0.25 if padding else 0.5 + dc_node, nch = make_dconv(nch, factor=factor, padding=padding, batch_norm=batch_norm) + setattr(self, f'up_dconv_{iskip}', dc_node) - self.segmap = nn.Conv2d(chans_out[-1], n_classes, 1) + self.upleg.append((m_node, dc_node)) # bottom up order + + self.segmap = nn.Conv2d(nch, n_classes, 1) - def getm(self, name, ind): - return getattr(self, f'{name}_{ind}') def forward(self, x): - - dskips = list() - - for ind in range(self.nskips): - dl = self.getm("downleg", ind) - dout = dl(x) - x = self.dsamp(dout) - sm = self.getm("skip", ind) - dskip = sm(dout) - dskips.append( dskip ) - + dump(x, "in") + + overs = list() + for skip, (dc,ds) in enumerate(self.downleg): + x = dc(x) + dump(x, f"down dc {skip}") + overs.append(x) + x = ds(x) + dump(x, f"down ds {skip}") + overs.reverse() x = self.bottom(x) - - dskips.reverse() # bottom-up - for ind in range(self.nskips): - s = dskips[ind] - um = self.getm("umerge", ind) - x = um(s, x) - ul = self.getm("upleg", ind) - x = ul(x) + dump(x, "bottom") + for over, (m,d) in zip(overs, self.upleg): + x = m(over, x) + dump(x, "up merge") + x = d(x) + dump(x, "up dc") x = self.segmap(x) return x +def dump(x, msg=""): + print(f'{x.shape} {msg}') diff --git a/wirecell/dnn/tracker.py b/wirecell/dnn/tracker.py new file mode 100644 index 0000000..1455259 --- /dev/null +++ b/wirecell/dnn/tracker.py @@ -0,0 +1,179 @@ +#+/usr/bin/env python +''' +Provide "machine learning experiment tracking". + +The "tracker" is an object that can be used to record input parameters and +intermediate and final values of training. + + +''' +from time import time +import shutil +from pathlib import Path +from hashlib import sha256 as hasher +import torch +class fsflow(): + ''' + Mimic a portion of mlflow tracking API using the filesystem as store. + + This is very limited w.r.t. mlflow. In particular is just provides basic + log_() methods and: + + - no autologging. + - no state. + - no entity return values. + + Logs: model, dataset, input example, signature, parameters and metrics + ''' + def __init__(self, basedir=None): + if basedir: + self.path = Path(basedir) + else: + self.path = Path(".") / "fsflow" + self.logpath = self.path / "fsflow-log.json" + + def log(self, thing): + ''' + Primitive sink of thing to log. Each entry is a line of text. + + Thing must be text or json-serializable. + ''' + if not isinstance(thing, str): + thing = json.dumps(thing) + with open(self.logpath, "a") as fp: + fp.write(thing + '\n') + + @property + def now(self): + return time() + + def set_tracking_uri(uri="flflow-log.json"): + ''' + Tracking URI is at best a log file name. + ''' + if uri.startswith("file://"): + uri = uri[7:] + if uri.startswith("//"): + uri[1:] + if uri.startswith("/"): + self.logpath = Path(uri) + else: + self.logpath = self.path / uri + + + def log_entry(self, kind, name, value): + ''' + Top level structured log entry. + ''' + dat = dict(t=self.now, kind=kind, name=name, value=value) + self.log(dat) + + def log_param(self, name, value): + self.log_entry("param", name, value) + + def log_params(self, params, **kwds): + params.update(kwds) + self.log_entry("params", "params", params) + + def log_metric(self, name, value): + self.log_entry("metric", name, value) + + def log_input(self, dataset, context=None, tags=None): + ''' + Dataset here means "tensor" + ''' + path = self.path + if context: + path = path / context + path = path / "input.npz" + path.parent.mkdir(parents=True, exists_ok=True) + numpy.savez_compressed(path, dataset) + self.log_entry("input", "path", str(path.absolute())) + + def log_artifact(self, local_path, artifact_path=None): + ''' + Copy local path to artifact path or default artifact directory. + ''' + local_path = Path(local_path) + if artifact_path: + artifact_path = self.path / "artifacts" / artifact_path + else: + artifact_path = self.path / "artifacts" / local_path.name + artifact_path.parent.mkdir(parents=True, exists_ok=True) + shutil.copy(local_path, artifact_path) + self.log_entry("artifact", local_path, artifact_path) + + def log_artifacts(self, local_dir, artifact_path=None): + ''' + Copy contents of local dir to artifact path or default artifact directory. + ''' + local_dir = Path(local_dir) + if artifact_path: + artifact_path = self.path / "artifacts" / artifact_path + else: + artifact_path = self.path / "artifacts" / local_path.name + artifact_path.parent.mkdir(parents=True, exists_ok=True) + shutil.copytree(local_dir, artifact_path) + self.log_entry("artifacts", local_dir, artifact_path) + + @property + def pytorch(self): + ''' + Mimic per-framework mlflow object attributes + ''' + return self + + def log_model(self, pytorch_model, artifact_path, **kwds): + if artifact_path: + artifact_path = self.path / "artifacts" / artifact_path + else: + artifact_path = self.path / "artifacts" / local_path.name + artifact_path.parent.mkdir(parents=True, exists_ok=True) + torch.save(pytorch_model.sate_dict(), artifact_path) + + self.log_entry("model", local_dir, artifact_path) + + + def create_experiment(self, name, artifact_location=None, tags=None): + ''' + Start an experiment and return a unique ID + ''' + # fixme: does the location need to be unique to the experiment? + if artifact_location: + artifact_location = Path(artifact_location) + else: + artifact_location = self.path / "artifacts" + artifact_location = artifact_location.absolute() + + tags = tags or dict() + + t = self.now + h = hasher() + h.update(str(t).encode()) + h.update(name.encode()) + h.update(artifact_location.encode()) + jtags = json.dumps(tags) + h.update(jtags.encode()) + eid = h.hexdigest() + + ent = dict(eid=eid, artifact_location=artifact_location, tags=tags) + self.log_entry("create_experiment", name, ent) + + return eid + + def set_experiment(self, experiment_name=None, experiment_id=None): + self.log_entry("set_experiment", experiment_id or experiment_name) + + def start_run(self, run_id=None, experiment_id=None, run_name=None, **kwds): + self.log_entry("start_run", run_id or run_name, + dict(kwds, run_id=run_id, experiment_id=experiment_id, run_name=run_name)) + def end_run(self, status = 'FINISHED'): + self.log_entry("end_run", 'status', status) + + + +try: + import mlflow + flow = mlflow +except ImportError: + flow = fsflow() diff --git a/wirecell/dnn/train.py b/wirecell/dnn/train.py index 58e179c..4682d6e 100644 --- a/wirecell/dnn/train.py +++ b/wirecell/dnn/train.py @@ -23,19 +23,23 @@ import torch.nn as nn class Classifier: - def __init__(self, net, optclass = optim.SGD, **optkwds): + def __init__(self, net, tracker=None, optclass = optim.SGD, **optkwds): self.net = net # model self.optimizer = optclass(net.parameters(), **optkwds) + self.tracker = tracker or Tracker() def epoch(self, data, criterion=nn.BCELoss()): ''' - One train over the batches of the data, return list of losses at each batch. + Train over the batches of the data, return list of losses at each batch. ''' epoch_losses = list() for features, labels in data: + self.optimizer.zero_grad() - prediction = self.net(src) + prediction = self.net(features) + print(f'{features.shape=} {labels.shape=} {prediction.shape=}') + loss = criterion(prediction, labels) loss.backward() self.optimizer.step() diff --git a/wirecell/util/plottools.py b/wirecell/util/plottools.py index 29ff0ff..49922b7 100644 --- a/wirecell/util/plottools.py +++ b/wirecell/util/plottools.py @@ -119,9 +119,32 @@ def __exit__(self, typ, value, traceback): -def pages(name, format=None): +def pages(name, format=None, single=False): + ''' + Return an instance of something like a PdfPages for the given format. + + Use like: + + >>> with pages(filename) as out: + >>> # make a matplotlib figure + >>> out.savefig() + + True multi-page formats (PDF) produce file with the given name. + + When a format that does not support pages (PNG) is requested then a page + number is inserted into the file name. The file name may be given with a + '%d' template to explicitly describe how the page number should be set. + Otherwise, the page number is appended to the base file name just before the + file name extension. + + However, if "single" is True, then this numbering is not performed and each + call to pages.savefig() will overwrite the file. + ''' + if name.endswith(".pdf") or format=="pdf": return PdfPages(name) + if single: + return NameSingleton(name, format) return NameSequence(name)