Migrate to Pytorch 1.0 #16

Open · wants to merge 9 commits into master · Changes from all commits
10 changes: 10 additions & 0 deletions .gitignore
@@ -0,0 +1,10 @@
models.zip
models_pytorch.zip
# python artifacts
__pycache__
*.pyc

# Temporary Files
*~
*.swp
*.swo
135 changes: 110 additions & 25 deletions Loader.py
@@ -1,4 +1,5 @@
from PIL import Image
import skimage.transform
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.utils.data as data
@@ -24,12 +25,10 @@ def __init__(self,contentPath,stylePath,fineSize):
self.image_list = [x for x in listdir(contentPath) if is_image_file(x)]
self.stylePath = stylePath
self.fineSize = fineSize
#self.normalize = transforms.Normalize(mean=[103.939,116.779,123.68],std=[1, 1, 1])
#normalize = transforms.Normalize(mean=[123.68,103.939,116.779],std=[1, 1, 1])
self.prep = transforms.Compose([
transforms.Scale(fineSize),
Rescale(fineSize, mode='PIL'),
CropModulus(16),
transforms.ToTensor(),
#transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])]), #turn to BGR
])

def __getitem__(self,index):
@@ -38,28 +37,114 @@ def __getitem__(self,index):
contentImg = default_loader(contentImgPath)
styleImg = default_loader(styleImgPath)

# resize
if(self.fineSize != 0):
w,h = contentImg.size
if(w > h):
if(w != self.fineSize):
neww = self.fineSize
newh = int(h*neww/w)
contentImg = contentImg.resize((neww,newh))
styleImg = styleImg.resize((neww,newh))
else:
if(h != self.fineSize):
newh = self.fineSize
neww = int(w*newh/h)
contentImg = contentImg.resize((neww,newh))
styleImg = styleImg.resize((neww,newh))


# Preprocess Images
contentImg = transforms.ToTensor()(contentImg)
styleImg = transforms.ToTensor()(styleImg)
return contentImg.squeeze(0),styleImg.squeeze(0),self.image_list[index]
contentImg = self.prep(contentImg)
styleImg = self.prep(styleImg)
return contentImg, styleImg, self.image_list[index]

def __len__(self):
# Total number of content images in the dataset.
return len(self.image_list)


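# Center-crop so that both sides are divisible by crop_modulus. With VGG19 this is
# presumably meant to keep sizes compatible with the encoder/decoder pair: the network
# downsamples by 2 at each pooling stage, so CropModulus(16) (as used in self.prep
# above) keeps feature-map shapes consistent down to relu5_1.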
class CropModulus(object):

def __init__(self, crop_modulus, mode='PIL'):
assert mode in ['PIL', 'HWC', 'CHW']
self.mode = mode
self.crop_modulus = crop_modulus

def __call__(self, im):
if self.mode == 'PIL':
W, H = im.size
elif self.mode == 'HWC':
H, W = im.shape[:2]
elif self.mode == 'CHW':
H, W = im.shape[1:3]
Hmod = H - H % self.crop_modulus
Wmod = W - W % self.crop_modulus
border_x = (W - Wmod) // 2
border_y = (H - Hmod) // 2
end_x = border_x + Wmod
end_y = border_y + Hmod
crop_box = (border_x, border_y, end_x, end_y)
if self.mode == 'PIL':
return im.crop(crop_box)
elif self.mode == 'HWC':
return im[border_y:end_y, border_x:end_x, :]
else: # self.mode == 'CHW'
return im[:, border_y:end_y, border_x:end_x]


class Rescale(object):
"""
Rescale the image in a sample to a given size, preserving aspect ratio.
If the input image's smaller dimension goes below the minimum size, scale
*up* so that it matches.

Args:
target_size (int): Desired output size.
The smaller (or bigger, if wished) of the image edges is matched
to output_size keeping aspect ratio the same.
If that puts the smaller edge below the minimum size,
then the smaller of the edges is matched to output_size.
"""

def __init__(self,
target_size,
max_size=2048,
min_size=224,
scaling='smaller_side',
mode='numpy',
interpolation=Image.BILINEAR
):
assert mode in ['PIL', 'numpy', 'torch'], mode
self.mode = mode
self.interpolation = interpolation
assert target_size >= 0, f'invalid target_size {target_size}'
assert scaling in ['smaller_side', 'bigger_side']
self.scaling = scaling
assert min_size <= target_size, f'min_size = {min_size} <= target_size = {target_size}. Baaaad idea!'
assert max_size <= 0 or target_size <= max_size, (target_size, max_size)
self.min_size = min_size
self.target_size = target_size
self.max_size = max_size

def target_shape(self, H, W, scaling=None):
scaling = scaling or self.scaling
if (scaling == 'bigger_side' and H > W) or (scaling == 'smaller_side' and H < W):
Wnew = int(np.round(W/H * self.target_size))
Hnew = self.target_size
else:
Wnew = self.target_size
Hnew = int(np.round(H/W * self.target_size))
return Hnew, Wnew

def __call__(self, image):
if self.mode == 'numpy':
H, W = image.shape[:2]
elif self.mode == 'torch':
H, W = image.shape[1:3]
else: #self.mode == 'PIL':
W, H = image.size
scaling = self.scaling
target_size = self.target_size

Hnew, Wnew = self.target_shape(H, W)

if Wnew < self.min_size or Hnew < self.min_size:
# this can only happen with scaling=='bigger_side'
# print(f'WARNING: image is too small after scaling. scaling UP instead of down.')
Hnew, Wnew = self.target_shape(H, W, 'smaller_side')

if self.mode == 'numpy':
return skimage.transform.resize(image, (Hnew, Wnew), preserve_range=True)
elif self.mode == 'torch':
# apparently, there is no simple way to resize a tensor directly, so go via PIL
image = transforms.functional.to_pil_image(image, mode='RGB')
image = image.resize((Wnew, Hnew), resample=self.interpolation)
image = transforms.functional.to_tensor(image)
return image
else: # self.mode == 'PIL'
return image.resize((Wnew, Hnew), resample=self.interpolation)
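# Usage sketch: Dataset.__init__ above combines these transforms as
#   transforms.Compose([Rescale(fineSize, mode='PIL'), CropModulus(16), transforms.ToTensor()])
# so each loaded image is rescaled by its smaller side and center-cropped to a
# multiple of 16 before conversion to a tensor.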


18 changes: 15 additions & 3 deletions Readme.md
@@ -1,13 +1,25 @@
## Universal Style Transfer

This is the Pytorch implementation of [Universal Style Transfer via Feature Transforms](https://arxiv.org/pdf/1705.08086.pdf).
This is a **modified** Pytorch implementation of [Universal Style Transfer via Feature Transforms](https://arxiv.org/pdf/1705.08086.pdf).

Official Torch implementation can be found [here](https://github.com/Yijunmaverick/UniversalStyleTransfer) and Tensorflow implementation can be found [here](https://github.com/eridgd/WCT-TF).
**It makes the following modifications:**

* A slightly improved parametrization, introduced in [Unsupervised Learning of Artistic Styles with Archetypal Style Analysis](https://arxiv.org/abs/1805.11155),
to control the trade-off between preserving detail and the strength of stylization (see the sketch below).
This is most useful if you're modifying the style of an image that is *already* an artwork,
but it may also be of interest for preserving detail in photos.
For the original parametrization, see [@sunshineatnoon's repository](https://github.com/sunshineatnoon/PytorchWCT) (or go back through the git log) until I get around to cleaning this up so both variants sit neatly side by side.
* Improved feature transforms as described by Lu et al. in [A Closed-form Solution to Universal Style Transfer](https://arxiv.org/abs/1906.00668).
This specifically leads to better contour preservation.
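
The per-level blending can be summarized as follows. This is a minimal illustrative sketch mirroring the expression used in `styleTransfer` in `WCT.py` (variable names follow that code; the helper function itself is not part of the pipeline):

```python
def blend_features(CsIlast, CsIorig, eIorig, gamma, delta):
    """Blend whitened-and-colored features before decoding (cf. styleTransfer in WCT.py).

    CsIlast: stylized features of the current (cascaded) intermediate result
    CsIorig: stylized features of the original content image
    eIorig:  plain encoder features of the original content image
    gamma:   overall stylization strength (1.0 = fully stylized)
    delta:   weight of the cascaded result versus the original content features
    """
    return gamma * (delta * CsIlast + (1 - delta) * CsIorig) + (1 - gamma) * eIorig
```

With `gamma = delta = 1` this reduces to the plain cascaded WCT result; lowering `gamma` mixes the untouched content features back in and so preserves more detail.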



The official Torch implementation can be found [here](https://github.com/Yijunmaverick/UniversalStyleTransfer) and a TensorFlow implementation can be found [here](https://github.com/eridgd/WCT-TF).

## Prerequisites
- [Pytorch](http://pytorch.org/)
- [torchvision](https://github.com/pytorch/vision)
- Pretrained encoder and decoder [models](https://drive.google.com/file/d/1M5KBPfqrIUZqrBZf78CIxLrMUT4lD4t9/view?usp=sharing) (trained for image reconstruction only); download and uncompress them under `models/`
- [scikit-image](https://scikit-image.org)
- CUDA + CuDNN

## Prepare images
151 changes: 73 additions & 78 deletions WCT.py
100644 → 100755
@@ -1,38 +1,42 @@
#! /usr/bin/env python3

import os
import torch
import argparse
import pprint
from PIL import Image
from torch.autograd import Variable
import torchvision.utils as vutils
import torchvision.datasets as datasets
from Loader import Dataset
from util import *
import scipy.misc
from torch.utils.serialization import load_lua
import time

parser = argparse.ArgumentParser(description='WCT Pytorch')
parser.add_argument('--contentPath',default='images/content',help='path to content images')
parser.add_argument('--stylePath',default='images/style',help='path to style images')
parser.add_argument('--workers', default=2, type=int, metavar='N',help='number of data loading workers (default: 2)')
parser.add_argument('--vgg1', default='models/vgg_normalised_conv1_1.t7', help='Path to the VGG conv1_1')
parser.add_argument('--vgg2', default='models/vgg_normalised_conv2_1.t7', help='Path to the VGG conv2_1')
parser.add_argument('--vgg3', default='models/vgg_normalised_conv3_1.t7', help='Path to the VGG conv3_1')
parser.add_argument('--vgg4', default='models/vgg_normalised_conv4_1.t7', help='Path to the VGG conv4_1')
parser.add_argument('--vgg5', default='models/vgg_normalised_conv5_1.t7', help='Path to the VGG conv5_1')
parser.add_argument('--decoder5', default='models/feature_invertor_conv5_1.t7', help='Path to the decoder5')
parser.add_argument('--decoder4', default='models/feature_invertor_conv4_1.t7', help='Path to the decoder4')
parser.add_argument('--decoder3', default='models/feature_invertor_conv3_1.t7', help='Path to the decoder3')
parser.add_argument('--decoder2', default='models/feature_invertor_conv2_1.t7', help='Path to the decoder2')
parser.add_argument('--decoder1', default='models/feature_invertor_conv1_1.t7', help='Path to the decoder1')
parser.add_argument('--encoder', default='models/vgg19_normalized.pth.tar', help='Path to the VGG conv1_1')
parser.add_argument('--decoder5', default='models/vgg19_normalized_decoder5.pth.tar', help='Path to the decoder5')
parser.add_argument('--decoder4', default='models/vgg19_normalized_decoder4.pth.tar', help='Path to the decoder4')
parser.add_argument('--decoder3', default='models/vgg19_normalized_decoder3.pth.tar', help='Path to the decoder3')
parser.add_argument('--decoder2', default='models/vgg19_normalized_decoder2.pth.tar', help='Path to the decoder2')
parser.add_argument('--decoder1', default='models/vgg19_normalized_decoder1.pth.tar', help='Path to the decoder1')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--transform-method', choices=['original', 'closed-form'], default='original',
help=('How to whiten and color the features. "original" for the formulation of Li et al. ( https://arxiv.org/abs/1705.08086 ) '
'or "closed-form" for method of Lu et al. ( https://arxiv.org/abs/1906.00668 '))
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument('--fineSize', type=int, default=512, help='rescale images so that the smaller side equals fineSize (aspect ratio is preserved)')
parser.add_argument('--outf', default='samples/', help='folder to output images')
parser.add_argument('--alpha', type=float,default=1, help='hyperparameter to blend wct feature and content feature')
parser.add_argument('--targets', default=[5, 4, 3, 2, 1], nargs='+', help='which layers to stylize at. Order matters!')
parser.add_argument('--gamma', type=float,default=1, help='hyperparameter to blend original content feature and colorized features. See Wynen et al. 2018 eq. (3)')
parser.add_argument('--delta', type=float,default=1, help='hyperparameter to blend wct features from current input and original input. See Wynen et al. 2018 eq. (3)')
parser.add_argument('--gpu', type=int, default=0, help="which gpu to run on. default is 0")
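# Example invocation (illustrative; paths and values are placeholders, adjust to your setup):
#   python WCT.py --cuda --fineSize 512 --transform-method closed-form --gamma 0.9 --delta 0.8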

args = parser.parse_args()
pprint.pprint(args.__dict__, indent=2)

try:
os.makedirs(args.outf)
@@ -45,70 +49,61 @@
batch_size=1,
shuffle=False)

wct = WCT(args)
def styleTransfer(contentImg,styleImg,imname,csF):

sF5 = wct.e5(styleImg)
cF5 = wct.e5(contentImg)
sF5 = sF5.data.cpu().squeeze(0)
cF5 = cF5.data.cpu().squeeze(0)
csF5 = wct.transform(cF5,sF5,csF,args.alpha)
Im5 = wct.d5(csF5)

sF4 = wct.e4(styleImg)
cF4 = wct.e4(Im5)
sF4 = sF4.data.cpu().squeeze(0)
cF4 = cF4.data.cpu().squeeze(0)
csF4 = wct.transform(cF4,sF4,csF,args.alpha)
Im4 = wct.d4(csF4)

sF3 = wct.e3(styleImg)
cF3 = wct.e3(Im4)
sF3 = sF3.data.cpu().squeeze(0)
cF3 = cF3.data.cpu().squeeze(0)
csF3 = wct.transform(cF3,sF3,csF,args.alpha)
Im3 = wct.d3(csF3)

sF2 = wct.e2(styleImg)
cF2 = wct.e2(Im3)
sF2 = sF2.data.cpu().squeeze(0)
cF2 = cF2.data.cpu().squeeze(0)
csF2 = wct.transform(cF2,sF2,csF,args.alpha)
Im2 = wct.d2(csF2)

sF1 = wct.e1(styleImg)
cF1 = wct.e1(Im2)
sF1 = sF1.data.cpu().squeeze(0)
cF1 = cF1.data.cpu().squeeze(0)
csF1 = wct.transform(cF1,sF1,csF,args.alpha)
Im1 = wct.d1(csF1)
def styleTransfer(wct, targets, contentImg, styleImg, imname, gamma, delta, outf, transform_method):

current_result = contentImg
eIorigs = [f.cpu().squeeze(0) for f in wct.encoder(contentImg, targets)]
eIss = [f.cpu().squeeze(0) for f in wct.encoder(styleImg, targets)]

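# Coarse-to-fine cascade: start at the deepest requested layer, whiten/color its
# features, decode to an image, then re-encode that intermediate result at the next
# (shallower) layer, so each pass refines the previous one.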
for i, (target, eIorig, eIs) in enumerate(zip(targets, eIorigs, eIss)):
print(f' stylizing at {target}')

if i == 0:
eIlast = eIorig
else:
eIlast = wct.encoder(current_result, target).cpu().squeeze(0)

CsIlast = wct.transform(eIlast, eIs, transform_method).float()
CsIorig = wct.transform(eIorig, eIs, transform_method).float()

decoder_input = (gamma*(delta * CsIlast + (1-delta) * CsIorig) \
+ (1-gamma) * eIorig)
decoder_input = decoder_input.unsqueeze(0).to(next(wct.parameters()).device)

decoder = wct.decoders[target]
current_result = decoder(decoder_input)

# save_image has this weird design to pad images with 4 pixels at default.
vutils.save_image(Im1.data.cpu().float(),os.path.join(args.outf,imname))
return

avgTime = 0
cImg = torch.Tensor()
sImg = torch.Tensor()
csF = torch.Tensor()
csF = Variable(csF)
if(args.cuda):
cImg = cImg.cuda(args.gpu)
sImg = sImg.cuda(args.gpu)
csF = csF.cuda(args.gpu)
wct.cuda(args.gpu)
for i,(contentImg,styleImg,imname) in enumerate(loader):
imname = imname[0]
print('Transferring ' + imname)
if (args.cuda):
contentImg = contentImg.cuda(args.gpu)
styleImg = styleImg.cuda(args.gpu)
cImg = Variable(contentImg,volatile=True)
sImg = Variable(styleImg,volatile=True)
start_time = time.time()
# WCT Style Transfer
styleTransfer(cImg,sImg,imname,csF)
end_time = time.time()
print('Elapsed time is: %f' % (end_time - start_time))
avgTime += (end_time - start_time)

print('Processed %d images. Averaged time is %f' % ((i+1),avgTime/(i+1)))
vutils.save_image(current_result.cpu().float(), os.path.join(outf,imname))
return current_result

def main():
wct = WCT(args)
if(args.cuda):
wct.cuda(args.gpu)

avgTime = 0
for i,(contentImg,styleImg,imname) in enumerate(loader):
if(args.cuda):
contentImg = contentImg.cuda(args.gpu)
styleImg = styleImg.cuda(args.gpu)
imname = imname[0]
print('\nTransferring ' + imname)
start_time = time.time()
# WCT Style Transfer
targets = [f'relu{t}_1' for t in args.targets]
styleTransfer(wct, targets, contentImg, styleImg, imname,
args.gamma, args.delta, args.outf, args.transform_method)
end_time = time.time()
print(' Elapsed time is: %f' % (end_time - start_time))
avgTime += (end_time - start_time)

print('Processed %d images. Averaged time is %f' % ((i+1),avgTime/(i+1)))

if __name__ == '__main__':
with torch.no_grad():
main()
Binary file added models/vgg19_normalized.pth.tar
Binary file not shown.
Binary file added models/vgg19_normalized_decoder1.pth.tar
Binary file not shown.
Binary file added models/vgg19_normalized_decoder2.pth.tar
Binary file not shown.
Binary file added models/vgg19_normalized_decoder3.pth.tar
Binary file not shown.
Binary file added models/vgg19_normalized_decoder4.pth.tar
Binary file not shown.
Binary file added models/vgg19_normalized_decoder5.pth.tar
Binary file not shown.