-
Notifications
You must be signed in to change notification settings - Fork 108
/
run_main.py
162 lines (130 loc) · 6.27 KB
/
run_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import argparse
import os
import sys

import numpy as np
import tensorflow as tf

import style_transfer
import utils
import vgg19
"""parsing and configuration"""
def parse_args():
desc = "Tensorflow implementation of 'Image Style Transfer Using Convolutional Neural Networks"
parser = argparse.ArgumentParser(description=desc)
parser.add_argument('--model_path', type=str, default='pre_trained_model', help='The directory where the pre-trained model was saved')
parser.add_argument('--content', type=str, default='images/tubingen.jpg', help='File path of content image (notation in the paper : p)', required = True)
parser.add_argument('--style', type=str, default='images/starry-night.jpg', help='File path of style image (notation in the paper : a)', required = True)
parser.add_argument('--output', type=str, default='result.jpg', help='File path of output image', required = True)
parser.add_argument('--loss_ratio', type=float, default=1e-3, help='Weight of content-loss relative to style-loss')
parser.add_argument('--content_layers', nargs='+', type=str, default=['conv4_2'], help='VGG19 layers used for content loss')
parser.add_argument('--style_layers', nargs='+', type=str, default=['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1'],
help='VGG19 layers used for style loss')
parser.add_argument('--content_layer_weights', nargs='+', type=float, default=[1.0], help='Content loss for each content is multiplied by corresponding weight')
parser.add_argument('--style_layer_weights', nargs='+', type=float, default=[.2,.2,.2,.2,.2],
help='Style loss for each content is multiplied by corresponding weight')
parser.add_argument('--initial_type', type=str, default='content', choices=['random','content','style'], help='The initial image for optimization (notation in the paper : x)')
parser.add_argument('--max_size', type=int, default=512, help='The maximum width or height of input images')
parser.add_argument('--content_loss_norm_type', type=int, default=3, choices=[1,2,3], help='Different types of normalization for content loss')
parser.add_argument('--num_iter', type=int, default=1000, help='The number of iterations to run')
return check_args(parser.parse_args())
"""checking arguments"""
def check_args(args):
try:
assert len(args.content_layers) == len(args.content_layer_weights)
except:
print ('content layer info and weight info must be matched')
return None
try:
assert len(args.style_layers) == len(args.style_layer_weights)
except:
print('style layer info and weight info must be matched')
return None
try:
assert args.max_size > 100
except:
print ('Too small size')
return None
model_file_path = args.model_path + '/' + vgg19.MODEL_FILE_NAME
try:
assert os.path.exists(model_file_path)
except:
print ('There is no %s'%model_file_path)
return None
try:
size_in_KB = os.path.getsize(model_file_path)
assert abs(size_in_KB - 534904783) < 10
except:
print('check file size of \'imagenet-vgg-verydeep-19.mat\'')
print('there are some files with the same name')
print('pre_trained_model used here can be downloaded from bellow')
print('http://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat')
return None
try:
assert os.path.exists(args.content)
except:
print('There is no %s'%args.content)
return None
try:
assert os.path.exists(args.style)
except:
print('There is no %s' % args.style)
return None
return args
"""add one dim for batch"""
# VGG19 requires input dimension to be (batch, height, width, channel)
def add_one_dim(image):
shape = (1,) + image.shape
return np.reshape(image, shape)
"""main"""
def main():
# parse arguments
args = parse_args()
if args is None:
exit()
# initiate VGG19 model
model_file_path = args.model_path + '/' + vgg19.MODEL_FILE_NAME
vgg_net = vgg19.VGG19(model_file_path)
# load content image and style image
content_image = utils.load_image(args.content, max_size=args.max_size)
style_image = utils.load_image(args.style, shape=(content_image.shape[1],content_image.shape[0]))
# initial guess for output
if args.initial_type == 'content':
init_image = content_image
elif args.initial_type == 'style':
init_image = style_image
elif args.initial_type == 'random':
init_image = np.random.normal(size=content_image.shape, scale=np.std(content_image))
# check input images for style-transfer
# utils.plot_images(content_image,style_image, init_image)
# create a map for content layers info
CONTENT_LAYERS = {}
for layer, weight in zip(args.content_layers,args.content_layer_weights):
CONTENT_LAYERS[layer] = weight
# create a map for style layers info
STYLE_LAYERS = {}
for layer, weight in zip(args.style_layers, args.style_layer_weights):
STYLE_LAYERS[layer] = weight
# open session
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
# build the graph
st = style_transfer.StyleTransfer(session = sess,
content_layer_ids = CONTENT_LAYERS,
style_layer_ids = STYLE_LAYERS,
init_image = add_one_dim(init_image),
content_image = add_one_dim(content_image),
style_image = add_one_dim(style_image),
net = vgg_net,
num_iter = args.num_iter,
loss_ratio = args.loss_ratio,
content_loss_norm_type = args.content_loss_norm_type,
)
# launch the graph in a session
result_image = st.update()
# close session
sess.close()
# remove batch dimension
shape = result_image.shape
result_image = np.reshape(result_image,shape[1:])
# save result
utils.save_image(result_image,args.output)
# utils.plot_images(content_image,style_image, result_image)
if __name__ == '__main__':
main()