Initial dnn main and UNet paper model
brettviren committed Oct 25, 2024
1 parent 2891c2c commit f3cae6c
Showing 5 changed files with 365 additions and 0 deletions.
1 change: 1 addition & 0 deletions setup.py
@@ -42,6 +42,7 @@
'wirecell-aux = wirecell.aux.__main__:main',
'wirecell-test = wirecell.test.__main__:main',
'wirecell-ls4gan = wirecell.ls4gan.__main__:main',
'wirecell-dnn = wirecell.dnn.__main__:main',
]
)
)
Empty file added wirecell/dnn/__init__.py
69 changes: 69 additions & 0 deletions wirecell/dnn/__main__.py
@@ -0,0 +1,69 @@
#!/usr/bin/env python3

import click
from wirecell.util.cli import context, log, jsonnet_loader

@context("dnn")
def cli(ctx):
'''
Wire Cell Deep Neural Network commands.
'''
pass

@cli.command('train-dnnroi')
@click.option("-c", "--config",
type=click.Path(),
help="Set configuration file")
@click.pass_context
def train_dnnroi(ctx, config):
'''
Train the DNNROI model.
'''
log.info(config)

@cli.command("vizmod")
@click.option("-s","--size", default=572, help="Number of input pixel (rows or columns)")
@click.option("-c","--channels", default=3, help="Number of input image channels")
@click.option("-C","--classes", default=6, help="Number of output classes")
@click.option("-b","--batch", default=1, help="Number of batch images")
@click.option("-o","--output", default=None, help="File name to fill with GraphViz dot")
@click.option("-m","--model", default="UNet",
type=click.Choice(["UNet","UsuyamaUNet", "MilesialUNet","list"]))
def vizmod(size, channels, classes, batch, output, model):
'''
    Produce a text summary of a named model and, if -o/--output is given, also a GraphViz diagram.
'''
import torch
from wirecell.dnn import models

if model == "list":
for one in dir(models):
if one[0].isupper():
print(one)
return

Mod = getattr(models, model)

mod = Mod(channels, classes, size)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
mod = mod.to(device)

from torchsummary import summary

shape = (channels, size, size)
summary(mod, input_size=shape, device=device)

if output:
from torchview import draw_graph
shape = (batch, channels, size, size)
gr = draw_graph(mod, input_size=shape, device=device)
with open(output, "w") as fp:
fp.write(str(gr.visual_graph))


def main():
cli(obj=dict())

if __name__ == '__main__':
main()
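# Example invocations, assuming the package is installed so the
# "wirecell-dnn" entry point from setup.py is on the PATH (vizmod also
# needs the torchsummary and torchview packages):
#
#   wirecell-dnn vizmod -m list
#   wirecell-dnn vizmod -m UNet -s 572 -c 3 -C 6 -o unet.dot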
3 changes: 3 additions & 0 deletions wirecell/dnn/models/__init__.py
@@ -0,0 +1,3 @@
from .unet import UNet
# from .usuyama import UNet as UsuyamaUNet
# from .milesial import UNet as MilesialUNet
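# Note: the vizmod command in wirecell.dnn.__main__ still offers UsuyamaUNet
# and MilesialUNet as choices; selecting either raises AttributeError until
# the imports above are restored.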
292 changes: 292 additions & 0 deletions wirecell/dnn/models/unet.py
@@ -0,0 +1,292 @@
#!/usr/bin/env python
'''
The Ronneberger, Fischer and Brox U-Net
Unlike several other implementations using the name "U-Net" that one runs
across, this one tries to exactly replicate what is in the paper.
Do NOT attempt to "improve" this model except where it may deviate from what is
described in the paper.
https://arxiv.org/abs/1505.04597
https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png
The following labels are used to identify units of the network and refer to the
u-net-architecture.png figure.
- dconv :: "double convolution" (dark blue arrow pair), two sub units of 3x3
convolution + ReLU. This makes up each major unit of the "U". The output of
this fans out to "dsamp" and to "skip".
- dsamp :: "down sampling" (red arrow), the 2x2 max pool downsampling that
connects output of one dconv to input of next dconv on the downward leg of the
"U".
- bottom :: the dconv making up the apex or bottom of the U.
- usamp :: "up sampling" (green arrow), the "up-conv 2x2" that takes input from a
dconv result and provides output to umerge.
- skip :: "skip connection" (gray arrow), the center crop of a dconv output that
provides input to umerge. A "skip level" counts the skips from 0 starting at
the top of the U. The "bottom" can be considered a skip level for purposes of
calculating the output size of its dconv.
- umerge :: "up merge" (gray+green arrows), concatenation of the skip result
with the usamp result and provides input to a dconv on the upward leg.
The default produces U-Net. However, the model is parameterized by the number
of skip connections. U-Net has four.
'''


import torch
import torch.nn as nn
from torch.nn.functional import grid_sample


def down_out_size(orig_size, skip_level):
'''
Return the output size from a down unit at a given skip level.
    The skip level counts from 0 at the top; each level transfers "over" the U via a skip connection or the bottom.
'''
size = orig_size
for dskip in range(skip_level + 1):
if dskip:
size = size // 2
size = size - 4
return size
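# For example, a 572 pixel input reproduces the sizes in the
# u-net-architecture.png figure: down_out_size(572, 0) == 568,
# down_out_size(572, 1) == 280, down_out_size(572, 2) == 136,
# down_out_size(572, 3) == 64 and down_out_size(572, 4) == 28 (the bottom).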

def up_in_size(orig_size, skip_level, nlevels = 4):
'''
Return the input size to an up unit (output of a skip) at a given skip level.
    The nlevels argument counts the number of skip connections across the U.
'''
size = down_out_size(orig_size, nlevels)
for uskip in range( nlevels - skip_level):
if uskip:
size = size - 4
size = size * 2
return size
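# Continuing that example, the inputs to the up units run from
# up_in_size(572, 3) == 56 just above the bottom to up_in_size(572, 0) == 392
# at the top, again matching the figure.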



def dimensions(in_channels = 1, in_size = 572, nskips = 4):
'''
Calculate image channel and pixel dimensions for elements of the U-Net.
    - in_channels :: the number of input image channels (1 for the U-Net paper).
    - in_size :: the size of both input image dimensions (572 for the U-Net paper).
    - nskips :: the number of skip connections (4 for the U-Net paper).
This returns four lists of size 2*nskips+1. Each element of a list
corresponds to one major "dconv" unit as we go along the U: nskips "down",
one "bottom" and nskips "up". The lists are:
- number of input channels
- number of output channels
- input size
- output size
The [nskips] element refers to the bottom dconv.
See skip_dimensions() to form similar lists from the output of this function
for the skip connections.
Note, the output segmentation map is excluded. The final element in the
lists refers to the top up dconv.
'''
chans_down_in = [in_channels] + [2**(6+n) for n in range(nskips)] # includes bottom
chans_down_out = [2**(6+n) for n in range(nskips+1)]
chans_up_in = list(chans_down_out[1:])
chans_up_in.reverse()
chans_in = chans_down_in + chans_up_in
chans_up_out = chans_down_in[1:]
chans_up_out.reverse()
chans_out = chans_down_out + chans_up_out

size_in = [in_size]
size_out = []
for skip in range(nskips):
siz = size_in[-1] - 4 # dconv reduction
size_out.append(siz)
size_in.append(siz // 2) # max pool reduction
size_out.append(size_in[-1] - 4) # bottom out
for rskip in range(nskips):
size_in.append(size_out[-1] * 2) # up conv
size_out.append(size_in[-1] - 4) # dconv reduction

return (chans_in, chans_out, size_in, size_out)
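# As a cross-check against the figure, dimensions(1, 572, 4) gives:
#   chans_in  == [1, 64, 128, 256, 512, 1024, 512, 256, 128]
#   chans_out == [64, 128, 256, 512, 1024, 512, 256, 128, 64]
#   size_in   == [572, 284, 140, 68, 32, 56, 104, 200, 392]
#   size_out  == [568, 280, 136, 64, 28, 52, 100, 196, 388]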


def skip_dimensions(dims):
'''
Reformat the output of dimensions() to the same form but for the skip
connections in order of skip level.
'''
(chans_in, chans_out, size_in, size_out) = dims

nskips = (len(chans_in)-1)//2

schans_in = chans_out[:nskips]
schans_out = schans_in # skips preserve channel dim

ssize_in = size_out[:nskips]
ssize_out = size_in[-nskips:]
ssize_out.reverse()
return (schans_in, schans_out, ssize_in, ssize_out)
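# For the same example, skip_dimensions(dimensions(1, 572, 4)) gives, in
# top-down skip-level order:
#   schans_in == schans_out == [64, 128, 256, 512]
#   ssize_in  == [568, 280, 136, 64]   (dconv outputs feeding the skips)
#   ssize_out == [392, 200, 104, 56]   (center crops handed to umerge)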



def dconv(in_channels, out_channels, kernel_size = 3, padding = 0):
'''
The double "conv 3x3, ReLU" unit.
'''
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding),
nn.ReLU(inplace=True)
)
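# With padding=0 each 3x3 convolution trims one pixel from every edge, so a
# dconv reduces each spatial dimension by 4 (e.g. 572 -> 568 at the top of
# the U). This is the "- 4" that appears throughout the size calculations
# above.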

def dsamp():
'''
The "down sampling".
'''
return nn.MaxPool2d(2)


def usamp(in_ch):
'''
The "up sampling".
'''
    return nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)
# return nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
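# Note: the UNet class below does not call this helper; its umerge unit
# constructs the equivalent ConvTranspose2d itself.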


def build_grid(source_size, target_size, batch_size = 1):
'''
Map output pixels to input pixels for cropping by grid_sample().
This assumes square images of given size.
'''
# simplified version of what is given in
# https://discuss.pytorch.org/t/cropping-a-minibatch-of-images-each-image-a-bit-differently/12247
k = float(target_size)/float(source_size)
direct = torch.linspace(-k,k,target_size).unsqueeze(0).repeat(target_size,1).unsqueeze(-1)
grid = torch.cat([direct,direct.transpose(1,0)],dim=2).unsqueeze(0)
return grid.repeat(batch_size, 1, 1, 1)
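# The returned grid has shape (batch_size, target_size, target_size, 2) as
# grid_sample() requires; it is only exercised by the commented-out
# grid_sample() path in skip.forward() below.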

class skip(nn.Module):
'''
The "skip connection" providing a core cropping.
'''
def __init__(self, source_size, target_size, batch_size=1):
super().__init__()
margin = (source_size - target_size)//2
self.crop = slice(margin, margin+target_size)

        # A fancier way to do it, but why?
# self.register_buffer('g', build_grid(source_size, target_size, batch_size))
# grid should have shape: (nbatch, nrows, ncols, 2)

def forward(self, x):
# x must be (nbatch, nchannel, nrows, ncols)
# print(f'grid: {self.g.shape} {self.g.dtype} {self.g.device}')
# print(f'data: {x.shape} {x.dtype} {x.device}')
# return grid_sample(x, self.g, align_corners=True, mode='nearest')
return x[:,:,self.crop,self.crop]
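# E.g. the bottom-most skip crops 64 -> 56 pixels: margin = (64-56)//2 = 4,
# so forward() keeps x[:, :, 4:60, 4:60].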


class umerge(nn.Module):
'''
The "upsample merge" of the outputs from a skip and a dconv.
'''
def __init__(self, nchannels):
'''
        Give the number of channels in the input to the upsampling port.
'''
super().__init__()
self._nchannels = nchannels
self.upsamp = nn.ConvTranspose2d(nchannels, nchannels//2, 2, stride=2)

def forward(self, over, up):
up = self.upsamp(up)
cat = torch.cat((over, up), dim=1)
return cat
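# Channel bookkeeping at the paper's bottom-most merge: up arrives with 1024
# channels, upsamp halves that to 512 and concatenation with the 512 channel
# skip crop restores 1024 channels for the first upward dconv.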


class UNet(nn.Module):
'''
U-Net model from the paper.
'''

def __init__(self, n_channels=3, n_classes=6, in_size=572, batch_size=1, nskips=4):
super().__init__()

self.nskips = nskips
dims = dimensions(n_channels, in_size, nskips)

# The major elements of the U
chans_in, chans_out, _, _ = dims

        # Note: we use setattr to make sure the PyTorch summary finds the submodules.

# The downward leg of the U.
for ind in range(nskips):
setattr(self, f'downleg_{ind}', dconv(chans_in[ind], chans_out[ind]))

# The bottom of the U
self.bottom = dconv(chans_in[nskips], chans_out[nskips])

# The upward leg of the U.
for count, ind in enumerate(range(nskips+1, 2*nskips+1)):
setattr(self, f'upleg_{count}', dconv(chans_in[ind], chans_out[ind]))

# The skip connections get applied top-down
schans_in, schans_out, ssize_in, ssize_out = skip_dimensions(dims)
for ind, ss in enumerate(zip(ssize_in, ssize_out)):
setattr(self, f'skip_{ind}', skip(*ss, batch_size=batch_size))

# And the merges are applied bottom-up.
# We bake in the rule that upsample input has 2x the number of channels as the skip output.
umerges = [umerge(2*nc) for nc in schans_out]
umerges.reverse()
for ind, um in enumerate(umerges):
            setattr(self, f'umerge_{ind}', um)

# Downsampler is data-independent and reused.
self.dsamp = dsamp()

self.segmap = nn.Conv2d(chans_out[-1], n_classes, 1)

def getm(self, name, ind):
return getattr(self, f'{name}_{ind}')

def forward(self, x):

dskips = list()

for ind in range(self.nskips):
dl = self.getm("downleg", ind)
dout = dl(x)
x = self.dsamp(dout)
sm = self.getm("skip", ind)
dskip = sm(dout)
dskips.append( dskip )

x = self.bottom(x)

dskips.reverse() # bottom-up
for ind in range(self.nskips):
s = dskips[ind]
um = self.getm("umerge", ind)
x = um(s, x)
ul = self.getm("upleg", ind)
x = ul(x)

x = self.segmap(x)
return x
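# A minimal smoke-test sketch using the module defaults (illustrative only,
# not run here): a 3 channel, 572 pixel input yields a 6 class, 388 pixel
# segmentation map.
#
#   net = UNet(n_channels=3, n_classes=6, in_size=572)
#   out = net(torch.rand(1, 3, 572, 572))
#   assert out.shape == (1, 6, 388, 388)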
