-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix and make more flexible the constructing of U-Net and DNNROI nets.
And start adding trainer and tracker.
- Loading branch information
1 parent
225eb46
commit 5898771
Showing
12 changed files
with
759 additions
and
231 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
#!/usr/bin/env python | ||
from .train import Classifier as Trainer | ||
from .data import Dataset | ||
from .model import Network |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
#!/usr/bin/env python | ||
''' | ||
Dataset and model specific to DNNROI training. | ||
See for example: | ||
https://www.phy.bnl.gov/~hyu/dunefd/dnn-roi-pdvd/Pytorch-UNet/data/ | ||
''' | ||
# fixme: add support to load from URL | ||
|
||
from wirecell.dnn.data import hdf | ||
|
||
from .transforms import Rec as Rect, Tru as Trut, Params as TrParams | ||
|
||
|
||
class Rec(hdf.Single): | ||
''' | ||
A DNNROI "rec" dataset. | ||
This consists of conventional sigproc results produced by WCT's | ||
OmnibusSigProc in HDF5 "frame file" form. | ||
''' | ||
|
||
file_re = r'.*g4-rec-r(\d+)\.h5' | ||
|
||
path_res = tuple( | ||
r'/(\d+)/%s\d'%tag for tag in [ | ||
'frame_loose_lf', 'frame_mp2_roi', 'frame_mp3_roi'] | ||
) | ||
|
||
def __init__(self, paths, | ||
file_re=None, path_res=None, | ||
trparams: TrParams = Trut.default_params, cache=False): | ||
|
||
dom = hdf.Domain(hdf.ReMatcher(file_re or self.file_re, | ||
path_res or self.path_res), | ||
transform=Rect(trparams), | ||
cache=cache, grad=True, | ||
name="dnnroirec") | ||
super().__init__(dom, paths) | ||
|
||
|
||
class Tru(hdf.Single): | ||
''' | ||
A DNNROI "tru" dataset. | ||
This consists of the target ROI | ||
''' | ||
|
||
file_re = r'.*g4-tru-r(\d+)\.h5' | ||
|
||
path_res = tuple( | ||
r'/(\d+)/%s\d'%tag for tag in ['frame_ductor'] | ||
) | ||
|
||
def __init__(self, paths, threshold = 0.5, | ||
file_re=None, path_res=None, | ||
trparams: TrParams = Trut.default_params, cache=False): | ||
|
||
dom = hdf.Domain(hdf.ReMatcher(file_re or self.file_re, | ||
path_res or self.path_res), | ||
transform=Trut(trparams, threshold), | ||
cache=cache, grad=False, | ||
name="dnnroitru") | ||
|
||
super().__init__(dom, paths) | ||
|
||
|
||
class Dataset(hdf.Multi): | ||
''' | ||
The full DNNROI dataset is effectively zip(Rec,Tru). | ||
''' | ||
def __init__(self, paths, threshold=0.5, cache=False): | ||
# fixme: allow configuring the transforms. | ||
super().__init__(Rec(paths, cache=cache), | ||
Tru(paths, threshold, cache=cache)) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import torch | ||
import torch.nn as nn | ||
from wirecell.dnn.models.unet import UNet | ||
|
||
class Network(nn.Module): | ||
|
||
def __init__(self, shape, n_channels=3, n_classes=1, batch_size=1): | ||
super().__init__() | ||
self.unet = UNet(n_channels, n_classes, shape, batch_size, | ||
batch_norm=True, bilinear=True) | ||
|
||
def forward(self, x): | ||
x = self.unet(x) | ||
return torch.sigmoid(x) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from wirecell.dnn.train import Classifier as Base | ||
|
||
|
||
class Classifier(Base): | ||
''' | ||
The DNNROI classifier | ||
''' | ||
def __init__(self, net, lr=0.1, momentum=0.9, weight_decay=0.0005, **optkwds): | ||
super().__init__(net, lr=lr, momentum=momentum, weight_decay=weight_decay, **optkwds) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
#!/usr/bin/env python | ||
''' | ||
The dataset transforms relevant to DNNROI | ||
''' | ||
|
||
from dataclasses import dataclass | ||
from typing import Type, Tuple | ||
|
||
|
||
@dataclass | ||
class DimParams: | ||
''' | ||
Per-dimension parameters for rec and tru dataset transforms. | ||
- crop :: a half-open range as slice | ||
- rebin :: an integer downsampling factor | ||
FYI, common channel counts per wire plane: | ||
- DUNE-VD: 256+320+288 | ||
- DUNE-HD: 800+800+960 | ||
- Uboone: 2400+2400+3456 | ||
- SBND-v1: 1986+1986+1666 | ||
- SBND-v0200: 1984+1984+1664 | ||
- ICARUS: 1056+5600+5600 | ||
''' | ||
crop: slice | ||
rebin: int = 1 | ||
|
||
def __post_init__(self): | ||
if not isinstance(self.crop, slice): | ||
self.crop = slice(*self.crop) | ||
|
||
|
||
@dataclass | ||
class Params: | ||
''' | ||
Common parameters for rec and tru dataset transforms. | ||
elech is for electronics channel dimension | ||
ticks is for sampling period dimension | ||
values are divided by norm | ||
''' | ||
elech: Type[DimParams] | ||
ticks: Type[DimParams] | ||
norm: float = 1.0 | ||
|
||
|
||
class Rec: | ||
''' | ||
The DNNROI "rec" data transformation. | ||
''' | ||
|
||
default_params = Params(DimParams((476, 952), 1), DimParams((0,6000), 10), 4000) | ||
|
||
def __init__(self, params: Params = None): | ||
''' | ||
Arguments: | ||
- params :: a Params | ||
''' | ||
self.params = params or self.default_params | ||
|
||
def crop(self, x): | ||
return x[:, self.params.elech.crop, self.params.ticks.crop] | ||
|
||
def rebin(self, x): | ||
ne, nt = self.params.elech.rebin, self.params.ticks.rebin, | ||
sh = (x.shape[0], # 0 | ||
x.shape[1] // ne, # 1 | ||
ne, # 2 | ||
x.shape[2] // nt, # 3 | ||
nt) # 4 | ||
return x.reshape(sh).mean(4).mean(2) # (imgch, elech_rebinned, ticks_rebinned) | ||
|
||
def transform(self, x): | ||
x = self.crop(x) | ||
x = self.rebin(x) | ||
x = x/self.params.norm | ||
return x | ||
|
||
|
||
def __call__(self, x): | ||
''' | ||
Input and output are shaped: | ||
(# of image channels/layers, # electronic channels, # of time samples) | ||
Last two dimensions of output are rebinned. | ||
''' | ||
return self.transform(x) | ||
|
||
|
||
class Tru(Rec): | ||
''' | ||
The DNNROI "tru" data transformation. | ||
This is same as "rec" but with a thresholding. | ||
''' | ||
|
||
default_params = Params(DimParams((476, 952), 1), DimParams((0,6000), 10), 200) | ||
|
||
def __init__(self, params: Params = None, threshold: float = 0.5): | ||
''' | ||
Arguments (see Rec for more): | ||
- threshold :: threshold for array values to be set to 0 or 1. | ||
''' | ||
super().__init__(params or self.default_params) | ||
self.threshold = threshold | ||
|
||
def __call__(self, x): | ||
x = self.transform(x) | ||
return (x > self.threshold).to(float) | ||
|
||
|
Oops, something went wrong.