-
Notifications
You must be signed in to change notification settings - Fork 387
/
demo.py
113 lines (88 loc) · 2.71 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
demo.py - Optimized style transfer pipeline for interactive demo.
"""
# system imports
import argparse
import logging
import subprocess
from threading import Lock
import time
# library imports
import caffe
import cv2
from skimage.transform import rescale
# local imports
from style import StyleTransfer
# argparse
# Command-line interface: both the style and the content image paths are
# mandatory and end up as ``style_img`` / ``content_img`` on the namespace.
parser = argparse.ArgumentParser(
    usage="demo.py -s <style_image> -c <content_image>",
    description="Run the optimized style transfer pipeline.",
)
parser.add_argument(
    "-s", "--style-img",
    required=True,
    type=str,
    help="input style (art) image",
)
parser.add_argument(
    "-c", "--content-img",
    required=True,
    type=str,
    help="input content image",
)
# style transfer
# Pool of style transfer workers. Each key is a threading.Lock that guards the
# StyleTransfer instance it maps to; init() fills the pool and st_api() claims
# a worker by acquiring its lock.
workers = dict()
def gpu_count():
    """
    Count the number of CUDA-capable GPUs (Linux only).

    Runs ``nvidia-smi -L``, which prints one line per GPU, and counts the
    non-empty lines of its output.

    Returns:
        int: number of GPUs reported, or 0 when ``nvidia-smi`` cannot be
        launched, exits with an error, or prints nothing.
    """
    try:
        # argv must be a list: a single string without shell=True is treated
        # as the program name ("nvidia-smi -L") and can never be executed
        output = subprocess.check_output(["nvidia-smi", "-L"])
    except (OSError, subprocess.CalledProcessError):
        # no driver / nvidia-smi not on PATH / command failed -> no GPUs
        return 0
    if isinstance(output, bytes):
        # Python 3 returns bytes; decode so the line handling below works
        output = output.decode("utf-8", "replace")
    # count only non-empty lines so that empty output yields 0 rather than 1
    return len([line for line in output.splitlines() if line.strip()])
def init(n_workers=1):
    """
    Initialize the style transfer backend.

    Creates ``n_workers`` ``StyleTransfer("googlenet", use_pbar=False)``
    instances and registers each one in the module-level ``workers`` dict,
    keyed by its own ``threading.Lock`` so that ``st_api`` can claim a worker
    exclusively.

    Args:
        n_workers (int): number of workers to create. Values below 1 are
            clamped to 1, because an empty pool would leave ``st_api`` with
            no worker to acquire.
    """
    # the original only special-cased 0; negative counts also produced an
    # empty pool, so clamp everything below 1
    n_workers = max(1, n_workers)
    for _ in range(n_workers):
        # RHS is evaluated before the subscript, so the worker is built first,
        # then its dedicated lock (same order as before)
        workers[Lock()] = StyleTransfer("googlenet", use_pbar=False)
def st_api(img_style, img_content, callback=None):
    """
    Style transfer API.

    Claims a free worker from the module-level ``workers`` pool and runs two
    ``transfer_style`` passes on it: first with ``length=360`` and
    ``init="content"``, then with ``length=512`` initialized from the first
    pass's generated image. The worker's lock is always released afterwards.

    Args:
        img_style: style (art) image, forwarded to ``transfer_style``.
        img_content: content image, forwarded to ``transfer_style``.
        callback: optional value forwarded as ``callback`` to both passes.

    Returns:
        The result of ``get_generated()`` after the final pass.

    Raises:
        RuntimeError: if the worker pool is empty (``init()`` not called).
    """
    # per-pass arguments; "init" is filled in below from the previous pass
    all_args = [{"length": 360, "ratio": 2e3, "n_iter": 32, "callback": callback},
                {"length": 512, "ratio": 2e4, "n_iter": 16, "callback": callback}]

    if not workers:
        # with no workers the polling loop below would spin forever
        raise RuntimeError("no style transfer workers available; call init() first")

    # acquire a worker: try each lock without blocking, back off if all busy
    st_lock = None
    st_worker = None
    while st_lock is None:
        # .items() works on both Python 2 and 3 (iteritems is Python-2-only)
        for lock, worker in workers.items():
            if lock.acquire(False):
                st_lock = lock
                st_worker = worker
                break
        else:
            # every worker is busy: wait briefly before polling again
            time.sleep(0.1)

    # start style transfer; release the worker even if a pass raises,
    # otherwise the worker would stay locked and the pool would shrink
    try:
        img_out = "content"
        for args in all_args:
            args["init"] = img_out
            st_worker.transfer_style(img_style, img_content, **args)
            img_out = st_worker.get_generated()
    finally:
        st_lock.release()
    return img_out
def main(args):
    """
    Entry point: style ``args.content_img`` after ``args.style_img`` and show
    the result in an OpenCV window named "Art" until a key is pressed.
    """
    # a single worker is enough to serve this one request
    init()

    style = caffe.io.load_image(args.style_img)
    content = caffe.io.load_image(args.content_img)
    stylized = st_api(style, content)

    # the result is treated as RGB and converted, since OpenCV displays BGR
    window = "Art"
    cv2.imshow(window, cv2.cvtColor(stylized, cv2.COLOR_RGB2BGR))
    cv2.waitKey()
    cv2.destroyWindow(window)
if __name__ == "__main__":
    # parse the command line and hand the resulting namespace to main()
    main(parser.parse_args())