From 452846fb26142b6213bc93f585e292abeaa1ec15 Mon Sep 17 00:00:00 2001
From: Daan Wynen
Date: Tue, 11 Dec 2018 21:47:01 +0100
Subject: [PATCH] add new models and conversion script.

---
 convertModels.py    | 133 +++++++++++++++++++++++++++++
 vgg19_decoders.py   | 199 ++++++++++++++++++++++++++++++++++++++++++++
 vgg19_normalized.py | 110 ++++++++++++++++++++++++
 3 files changed, 442 insertions(+)
 create mode 100644 convertModels.py
 create mode 100644 vgg19_decoders.py
 create mode 100644 vgg19_normalized.py

diff --git a/convertModels.py b/convertModels.py
new file mode 100644
index 0000000..6dff8d6
--- /dev/null
+++ b/convertModels.py
@@ -0,0 +1,133 @@
+import numpy as np
+from imageio import imread
+from scipy.stats import describe
+import torch
+import torch.nn as nn
+from torch.utils.serialization import load_lua  # removed in PyTorch 1.0; run with torch <= 0.4
+import torchvision.transforms as tvt
+
+import modelsNIPS
+import vgg19_normalized
+import vgg19_decoders
+
+CHECKPOINT_ENCODER_PY = 'models/vgg19_normalized.pth.tar'
+LUA_CHECKPOINT_VGG = 'models/vgg_normalised_conv{}_1.t7'
+
+TEST_IMAGE = 'images/content/in4.jpg'
+image_np = imread(TEST_IMAGE)  # keep uint8: ToTensor only rescales 8-bit input
+
+ENCODERS = modelsNIPS.encoder1, modelsNIPS.encoder2, modelsNIPS.encoder3, modelsNIPS.encoder4, modelsNIPS.encoder5
+DECODERS = modelsNIPS.decoder1, modelsNIPS.decoder2, modelsNIPS.decoder3, modelsNIPS.decoder4, modelsNIPS.decoder5
+
+# put the image into [0, 1], but don't center or normalize like for other nets
+trans = tvt.ToTensor()
+image_pt = trans(image_np).unsqueeze(0)
+
+
+def convert_encoder():
+    vgg_lua = [load_lua(LUA_CHECKPOINT_VGG.format(k)) for k in range(1, 6)]
+    vgg_lua_ = [e(vl) for e, vl in zip(ENCODERS, vgg_lua)]
+
+    vgg_py = vgg19_normalized.VGG19_normalized()
+
+    # map each PyTorch conv block to its index in the Lua Sequential
+    matching = {
+        vgg_py.blocks['conv1_1']: 2,
+        vgg_py.blocks['conv1_2']: 5,
+
+        vgg_py.blocks['conv2_1']: 9,
+        vgg_py.blocks['conv2_2']: 12,
+
+        vgg_py.blocks['conv3_1']: 16,
+        vgg_py.blocks['conv3_2']: 19,
+        vgg_py.blocks['conv3_3']: 22,
+        vgg_py.blocks['conv3_4']: 25,
+
+        vgg_py.blocks['conv4_1']: 29,
+        vgg_py.blocks['conv4_2']: 32,
+        vgg_py.blocks['conv4_3']: 35,
+        vgg_py.blocks['conv4_4']: 38,
+
+        vgg_py.blocks['conv5_1']: 42,
+    }
+
+    for torch_conv, lua_conv_i in matching.items():
+        weights = nn.Parameter(vgg_lua[4].get(lua_conv_i).weight.float())
+        bias = nn.Parameter(vgg_lua[4].get(lua_conv_i).bias.float())
+        torch_conv.load_state_dict({'weight': weights, 'bias': bias})
+
+    torch.save(vgg_py.state_dict(), CHECKPOINT_ENCODER_PY)
+
+    for k in range(1, 6):
+        print(f'encoder {k}')
+        e_lua = vgg_lua_[k - 1]
+        with torch.no_grad():
+            al = e_lua(image_pt)
+            ap = vgg_py(image_pt, targets=f'relu{k}_1')
+        assert al.shape == ap.shape, (al.shape, ap.shape)
+        diff = np.abs((al - ap).numpy())
+        print(describe(diff.flatten()))
+        print(np.percentile(diff, 99))
+        print()
+
+
+def convert_decoder(K):
+    print(f'converting decoder from layer {K}')
+    decoderK_lua = load_lua(f'models/feature_invertor_conv{K}_1.t7')
+    decoderK_legacy = DECODERS[K - 1](decoderK_lua)
+    decoderK_py = vgg19_decoders.DECODERS[K - 1]()
+
+    # negative indices count back from the end of the Lua Sequential,
+    # so the same table works for all five decoder depths
+    matching = {
+        'conv5_1': -41,
+
+        'conv4_4': -37,
+        'conv4_3': -34,
+        'conv4_2': -31,
+        'conv4_1': -28,
+
+        'conv3_4': -24,
+        'conv3_3': -21,
+        'conv3_2': -18,
+        'conv3_1': -15,
+
+        'conv2_2': -11,
+        'conv2_1': -8,
+
+        'conv1_2': -4,
+        'conv1_1': -1,
+    }
+
+    for torch_conv, lua_conv_i in matching.items():
+        if -lua_conv_i >= len(decoderK_lua):
+            continue  # this layer does not exist in the shallower decoders
+        print(f'  {torch_conv}')
+        weights = nn.Parameter(decoderK_lua.get(lua_conv_i).weight.float())
+        bias = nn.Parameter(decoderK_lua.get(lua_conv_i).bias.float())
+        decoderK_py.blocks[torch_conv].load_state_dict({'weight': weights, 'bias': bias})
+
+    torch.save(decoderK_py.state_dict(), f'models/vgg19_normalized_decoder{K}.pth.tar')
+
+    encoder = vgg19_normalized.VGG19_normalized()
+    encoder.load_state_dict(torch.load(CHECKPOINT_ENCODER_PY))
+
+    print(f'testing encoding/decoding at layer {K}')
+
+    with torch.no_grad():
+        features = encoder(image_pt, targets=f'relu{K}_1')
+        rgb_legacy = decoderK_legacy(features)
+        rgb_py = decoderK_py(features)
+    assert rgb_legacy.shape == rgb_py.shape, (rgb_legacy.shape, rgb_py.shape)
+    diff = np.abs((rgb_legacy - rgb_py).numpy())
+    print(describe(diff.flatten()))
+    print(np.percentile(diff, 99))
+    print()
+
+
+def main():
+    convert_encoder()
+
+    for K in range(1, 6):
+        convert_decoder(K)
+
+    print('DONE')
+
+
+if __name__ == '__main__':
+    main()
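For reference, a minimal sketch of how the converted checkpoints produced by this script could be consumed downstream. The module names and checkpoint paths are the ones used above; the choice of level `K` and the final clamp are illustrative only:

```python
# Sketch: reconstruct an image through a converted encoder/decoder pair.
import torch
import torchvision.transforms as tvt
from imageio import imread

import vgg19_normalized
import vgg19_decoders

K = 3  # reconstruct from relu3_1 (any of 1..5 works)

encoder = vgg19_normalized.VGG19_normalized()
encoder.load_state_dict(torch.load('models/vgg19_normalized.pth.tar'))
decoder = vgg19_decoders.DECODERS[K - 1]()
decoder.load_state_dict(torch.load(f'models/vgg19_normalized_decoder{K}.pth.tar'))

# uint8 image -> float tensor in [0, 1]; no further normalization
image = tvt.ToTensor()(imread('images/content/in4.jpg')).unsqueeze(0)

with torch.no_grad():
    features = encoder(image, targets=f'relu{K}_1')
    reconstruction = decoder(features).clamp(0, 1)
```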
diff --git a/vgg19_decoders.py b/vgg19_decoders.py
new file mode 100644
index 0000000..c5bce7c
--- /dev/null
+++ b/vgg19_decoders.py
@@ -0,0 +1,199 @@
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+
+
+class VGG19Decoder1(nn.Module):
+
+    def __init__(self):
+        super(VGG19Decoder1, self).__init__()
+
+        # input shape originally 224 x 224
+        self.blocks = OrderedDict([
+            ('pad1_1', nn.ReflectionPad2d(1)),       # 226 x 226
+            ('conv1_1', nn.Conv2d(64, 3, 3, 1, 0)),  # 224 x 224
+        ])
+
+        self.seq = nn.Sequential(self.blocks)
+
+    def forward(self, x, targets=None):
+        return self.seq(x)
+
+
+class VGG19Decoder2(nn.Module):
+
+    def __init__(self):
+        super(VGG19Decoder2, self).__init__()
+
+        # input shape originally 224 x 224
+        self.blocks = OrderedDict([
+            ('pad2_1', nn.ReflectionPad2d(1)),
+            ('conv2_1', nn.Conv2d(128, 64, 3, 1, 0)),
+            ('relu2_1', nn.ReLU(inplace=True)),        # 112 x 112
+
+            ('unpool1', nn.Upsample(scale_factor=2)),  # 224 x 224
+            ('pad1_2', nn.ReflectionPad2d(1)),
+            ('conv1_2', nn.Conv2d(64, 64, 3, 1, 0)),
+            ('relu1_2', nn.ReLU(inplace=True)),        # 224 x 224
+            ('pad1_1', nn.ReflectionPad2d(1)),         # 226 x 226
+            ('conv1_1', nn.Conv2d(64, 3, 3, 1, 0)),    # 224 x 224
+        ])
+
+        self.seq = nn.Sequential(self.blocks)
+
+    def forward(self, x, targets=None):
+        return self.seq(x)
+
+
+class VGG19Decoder3(nn.Module):
+
+    def __init__(self):
+        super(VGG19Decoder3, self).__init__()
+
+        # input shape originally 224 x 224
+        self.blocks = OrderedDict([
+            ('pad3_1', nn.ReflectionPad2d(1)),
+            ('conv3_1', nn.Conv2d(256, 128, 3, 1, 0)),
+            ('relu3_1', nn.ReLU(inplace=True)),        # 56 x 56
+
+            ('unpool2', nn.Upsample(scale_factor=2)),  # 112 x 112
+            ('pad2_2', nn.ReflectionPad2d(1)),
+            ('conv2_2', nn.Conv2d(128, 128, 3, 1, 0)),
+            ('relu2_2', nn.ReLU(inplace=True)),        # 112 x 112
+            ('pad2_1', nn.ReflectionPad2d(1)),
+            ('conv2_1', nn.Conv2d(128, 64, 3, 1, 0)),
+            ('relu2_1', nn.ReLU(inplace=True)),        # 112 x 112
+
+            ('unpool1', nn.Upsample(scale_factor=2)),  # 224 x 224
+            ('pad1_2', nn.ReflectionPad2d(1)),
+            ('conv1_2', nn.Conv2d(64, 64, 3, 1, 0)),
+            ('relu1_2', nn.ReLU(inplace=True)),        # 224 x 224
+            ('pad1_1', nn.ReflectionPad2d(1)),         # 226 x 226
+            ('conv1_1', nn.Conv2d(64, 3, 3, 1, 0)),    # 224 x 224
+        ])
+
+        self.seq = nn.Sequential(self.blocks)
+
+    def forward(self, x, targets=None):
+        return self.seq(x)
+
+
+class VGG19Decoder4(nn.Module):
+
+    def __init__(self):
+        super(VGG19Decoder4, self).__init__()
+
+        # input shape originally 224 x 224
+        self.blocks = OrderedDict([
+            ('pad4_1', nn.ReflectionPad2d(1)),
+            ('conv4_1', nn.Conv2d(512, 256, 3, 1, 0)),
+            ('relu4_1', nn.ReLU(inplace=True)),        # 28 x 28
+            ('unpool3', nn.Upsample(scale_factor=2)),  # 56 x 56
+            ('pad3_4', nn.ReflectionPad2d(1)),
+            ('conv3_4', nn.Conv2d(256, 256, 3, 1, 0)),
+            ('relu3_4', nn.ReLU(inplace=True)),        # 56 x 56
+            ('pad3_3', nn.ReflectionPad2d(1)),
+            ('conv3_3', nn.Conv2d(256, 256, 3, 1, 0)),
+            ('relu3_3', nn.ReLU(inplace=True)),        # 56 x 56
+            ('pad3_2', nn.ReflectionPad2d(1)),
+            ('conv3_2', nn.Conv2d(256, 256, 3, 1, 0)),
+            ('relu3_2', nn.ReLU(inplace=True)),        # 56 x 56
+            ('pad3_1', nn.ReflectionPad2d(1)),
+            ('conv3_1', nn.Conv2d(256, 128, 3, 1, 0)),
+            ('relu3_1', nn.ReLU(inplace=True)),        # 56 x 56
+
+            ('unpool2', nn.Upsample(scale_factor=2)),  # 112 x 112
+            ('pad2_2', nn.ReflectionPad2d(1)),
+            ('conv2_2', nn.Conv2d(128, 128, 3, 1, 0)),
+            ('relu2_2', nn.ReLU(inplace=True)),        # 112 x 112
+            ('pad2_1', nn.ReflectionPad2d(1)),
+            ('conv2_1', nn.Conv2d(128, 64, 3, 1, 0)),
+            ('relu2_1', nn.ReLU(inplace=True)),        # 112 x 112
+
+            ('unpool1', nn.Upsample(scale_factor=2)),  # 224 x 224
+            ('pad1_2', nn.ReflectionPad2d(1)),
+            ('conv1_2', nn.Conv2d(64, 64, 3, 1, 0)),
+            ('relu1_2', nn.ReLU(inplace=True)),        # 224 x 224
+            ('pad1_1', nn.ReflectionPad2d(1)),         # 226 x 226
+            ('conv1_1', nn.Conv2d(64, 3, 3, 1, 0)),    # 224 x 224
+        ])
+
+        self.seq = nn.Sequential(self.blocks)
+
+    def forward(self, x, targets=None):
+        return self.seq(x)
+
+
+class VGG19Decoder5(nn.Module):
+
+    def __init__(self):
+        super(VGG19Decoder5, self).__init__()
+
+        # input shape originally 224 x 224
+        self.blocks = OrderedDict([
+            ('pad5_1', nn.ReflectionPad2d(1)),
+            ('conv5_1', nn.Conv2d(512, 512, 3, 1, 0)),
+            ('relu5_1', nn.ReLU(inplace=True)),        # 14 x 14
+
+            ('unpool4', nn.Upsample(scale_factor=2)),  # 28 x 28
+            ('pad4_4', nn.ReflectionPad2d(1)),
+            ('conv4_4', nn.Conv2d(512, 512, 3, 1, 0)),
+            ('relu4_4', nn.ReLU(inplace=True)),        # 28 x 28
+            ('pad4_3', nn.ReflectionPad2d(1)),
+            ('conv4_3', nn.Conv2d(512, 512, 3, 1, 0)),
+            ('relu4_3', nn.ReLU(inplace=True)),        # 28 x 28
+            ('pad4_2', nn.ReflectionPad2d(1)),
+            ('conv4_2', nn.Conv2d(512, 512, 3, 1, 0)),
+            ('relu4_2', nn.ReLU(inplace=True)),        # 28 x 28
+            ('pad4_1', nn.ReflectionPad2d(1)),
+            ('conv4_1', nn.Conv2d(512, 256, 3, 1, 0)),
+            ('relu4_1', nn.ReLU(inplace=True)),        # 28 x 28
+
+            ('unpool3', nn.Upsample(scale_factor=2)),  # 56 x 56
+            ('pad3_4', nn.ReflectionPad2d(1)),
+            ('conv3_4', nn.Conv2d(256, 256, 3, 1, 0)),
+            ('relu3_4', nn.ReLU(inplace=True)),        # 56 x 56
+            ('pad3_3', nn.ReflectionPad2d(1)),
+            ('conv3_3', nn.Conv2d(256, 256, 3, 1, 0)),
+            ('relu3_3', nn.ReLU(inplace=True)),        # 56 x 56
+            ('pad3_2', nn.ReflectionPad2d(1)),
+            ('conv3_2', nn.Conv2d(256, 256, 3, 1, 0)),
+            ('relu3_2', nn.ReLU(inplace=True)),        # 56 x 56
+            ('pad3_1', nn.ReflectionPad2d(1)),
+            ('conv3_1', nn.Conv2d(256, 128, 3, 1, 0)),
+            ('relu3_1', nn.ReLU(inplace=True)),        # 56 x 56
+
+            ('unpool2', nn.Upsample(scale_factor=2)),  # 112 x 112
+            ('pad2_2', nn.ReflectionPad2d(1)),
+            ('conv2_2', nn.Conv2d(128, 128, 3, 1, 0)),
+            ('relu2_2', nn.ReLU(inplace=True)),        # 112 x 112
+            ('pad2_1', nn.ReflectionPad2d(1)),
+            ('conv2_1', nn.Conv2d(128, 64, 3, 1, 0)),
+            ('relu2_1', nn.ReLU(inplace=True)),        # 112 x 112
+
+            ('unpool1', nn.Upsample(scale_factor=2)),  # 224 x 224
+            ('pad1_2', nn.ReflectionPad2d(1)),
+            ('conv1_2', nn.Conv2d(64, 64, 3, 1, 0)),
+            ('relu1_2', nn.ReLU(inplace=True)),        # 224 x 224
+            ('pad1_1', nn.ReflectionPad2d(1)),         # 226 x 226
+            ('conv1_1', nn.Conv2d(64, 3, 3, 1, 0)),    # 224 x 224
+        ])
+
+        self.seq = nn.Sequential(self.blocks)
+
+    def forward(self, x, targets=None):
+        return self.seq(x)
+
+
+DECODERS = (VGG19Decoder1, VGG19Decoder2, VGG19Decoder3, VGG19Decoder4, VGG19Decoder5)
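Each decoder K mirrors the encoder from `relu{K}_1` back to RGB, so its output should recover the input resolution. A quick shape sanity check, assuming a 224 x 224 input and the channel widths defined above (the random features stand in for real activations):

```python
# Sketch: decoder K maps relu{K}_1 features back to a 3 x 224 x 224 image.
import torch
from vgg19_decoders import DECODERS

# channel width of relu{K}_1 in the encoder
CHANNELS = {1: 64, 2: 128, 3: 256, 4: 512, 5: 512}

for K, cls in enumerate(DECODERS, start=1):
    side = 224 // 2 ** (K - 1)  # spatial size at relu{K}_1
    feats = torch.rand(1, CHANNELS[K], side, side)
    with torch.no_grad():
        out = cls()(feats)
    assert out.shape == (1, 3, 224, 224), (K, out.shape)
```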
diff --git a/vgg19_normalized.py b/vgg19_normalized.py
new file mode 100644
index 0000000..76ec728
--- /dev/null
+++ b/vgg19_normalized.py
@@ -0,0 +1,110 @@
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+
+
+class VGG19_normalized(nn.Module):
+    def __init__(self):
+        """
+        Normalized VGG19.
+        Takes RGB within [0, 1] as input.
+        Do NOT mean/std-normalize the input as for other VGG models!
+        """
+        super(VGG19_normalized, self).__init__()
+
+        # fixed preprocessing as a 1x1 convolution: flips RGB to BGR,
+        # scales to [0, 255] and subtracts the BGR channel means.
+        # registered as buffers (not parameters) so it is never trained.
+        self.register_buffer(
+            'preprocess_weight',
+            torch.FloatTensor([[[[  0.]], [[  0.]], [[255.]]],
+                               [[[  0.]], [[255.]], [[  0.]]],
+                               [[[255.]], [[  0.]], [[  0.]]]]))
+        self.register_buffer(
+            'preprocess_bias',
+            torch.FloatTensor([-103.9390, -116.7790, -123.6800]))
+
+        # input shape originally 224 x 224
+        self.blocks = OrderedDict([
+            ('pad1_1', nn.ReflectionPad2d(1)),   # 226 x 226
+            ('conv1_1', nn.Conv2d(3, 64, 3, 1, 0)),
+            ('relu1_1', nn.ReLU(inplace=True)),  # 224 x 224
+            ('pad1_2', nn.ReflectionPad2d(1)),
+            ('conv1_2', nn.Conv2d(64, 64, 3, 1, 0)),
+            ('relu1_2', nn.ReLU(inplace=True)),  # 224 x 224
+            ('pool1', nn.MaxPool2d(kernel_size=2, stride=2)),  # 112 x 112
+
+            ('pad2_1', nn.ReflectionPad2d(1)),
+            ('conv2_1', nn.Conv2d(64, 128, 3, 1, 0)),
+            ('relu2_1', nn.ReLU(inplace=True)),  # 112 x 112
+            ('pad2_2', nn.ReflectionPad2d(1)),
+            ('conv2_2', nn.Conv2d(128, 128, 3, 1, 0)),
+            ('relu2_2', nn.ReLU(inplace=True)),  # 112 x 112
+            ('pool2', nn.MaxPool2d(kernel_size=2, stride=2)),  # 56 x 56
+
+            ('pad3_1', nn.ReflectionPad2d(1)),
+            ('conv3_1', nn.Conv2d(128, 256, 3, 1, 0)),
+            ('relu3_1', nn.ReLU(inplace=True)),  # 56 x 56
+            ('pad3_2', nn.ReflectionPad2d(1)),
+            ('conv3_2', nn.Conv2d(256, 256, 3, 1, 0)),
+            ('relu3_2', nn.ReLU(inplace=True)),  # 56 x 56
+            ('pad3_3', nn.ReflectionPad2d(1)),
+            ('conv3_3', nn.Conv2d(256, 256, 3, 1, 0)),
+            ('relu3_3', nn.ReLU(inplace=True)),  # 56 x 56
+            ('pad3_4', nn.ReflectionPad2d(1)),
+            ('conv3_4', nn.Conv2d(256, 256, 3, 1, 0)),
+            ('relu3_4', nn.ReLU(inplace=True)),  # 56 x 56
+            ('pool3', nn.MaxPool2d(kernel_size=2, stride=2)),  # 28 x 28
+
+            ('pad4_1', nn.ReflectionPad2d(1)),
+            ('conv4_1', nn.Conv2d(256, 512, 3, 1, 0)),
+            ('relu4_1', nn.ReLU(inplace=True)),  # 28 x 28
+            ('pad4_2', nn.ReflectionPad2d(1)),
+            ('conv4_2', nn.Conv2d(512, 512, 3, 1, 0)),
+            ('relu4_2', nn.ReLU(inplace=True)),  # 28 x 28
+            ('pad4_3', nn.ReflectionPad2d(1)),
+            ('conv4_3', nn.Conv2d(512, 512, 3, 1, 0)),
+            ('relu4_3', nn.ReLU(inplace=True)),  # 28 x 28
+            ('pad4_4', nn.ReflectionPad2d(1)),
+            ('conv4_4', nn.Conv2d(512, 512, 3, 1, 0)),
+            ('relu4_4', nn.ReLU(inplace=True)),  # 28 x 28
+            ('pool4', nn.MaxPool2d(kernel_size=2, stride=2)),  # 14 x 14
+
+            ('pad5_1', nn.ReflectionPad2d(1)),
+            ('conv5_1', nn.Conv2d(512, 512, 3, 1, 0)),
+            ('relu5_1', nn.ReLU(inplace=True)),  # 14 x 14
+        ])
+
+        self.seq = nn.Sequential(self.blocks)
+
+    def forward(self, x, targets=None):
+        # the preprocessing weights live in buffers, not parameters,
+        # so this stage stays fixed during training
+        out = nn.functional.conv2d(x,
+                                   weight=self.preprocess_weight,
+                                   bias=self.preprocess_bias)
+
+        # by default, just run the whole thing
+        targets = targets or 'relu5_1'
+
+        # a single layer name: return that activation directly
+        if isinstance(targets, str):
+            assert targets in self.blocks.keys(), f'"{targets}" is not a valid target'
+            for n, b in self.blocks.items():
+                out = b(out)
+                if n == targets:
+                    return out
+
+        # a list of layer names: collect them and return in the requested order
+        for t in targets:
+            assert t in self.blocks.keys(), f'"{t}" is not a valid target'
+
+        results = dict()
+        for n, b in self.blocks.items():
+            out = b(out)
+            if n in targets:
+                results[n] = out
+            if len(results) == len(set(targets)):
+                break
+
+        results = [results[t] for t in targets]
+        return results
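Finally, a short sketch of the `targets` API defined in `forward` above: a single layer name returns one tensor, a list of names returns the activations in the requested order. The checkpoint path matches the conversion script; the random input stands in for a real image in [0, 1]:

```python
# Sketch: extracting features from several layers in one pass.
import torch
from vgg19_normalized import VGG19_normalized

encoder = VGG19_normalized()
encoder.load_state_dict(torch.load('models/vgg19_normalized.pth.tar'))

x = torch.rand(1, 3, 224, 224)  # RGB in [0, 1]; no mean/std normalization
with torch.no_grad():
    r1, r3, r5 = encoder(x, targets=['relu1_1', 'relu3_1', 'relu5_1'])

print(r1.shape)  # torch.Size([1, 64, 224, 224])
print(r3.shape)  # torch.Size([1, 256, 56, 56])
print(r5.shape)  # torch.Size([1, 512, 14, 14])
```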