forked from Rayhane-mamah/Tacotron-2
-
Notifications
You must be signed in to change notification settings - Fork 47
/
demo_server.py
164 lines (149 loc) · 4.94 KB
/
demo_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import argparse
import chardet
import thriftpy
import falcon
import tensorflow as tf
import numpy as np
import io
import re
import os
import json
import urllib
from datasets import audio
from mainstay import Mainstay
from hparams import hparams
from infolog import log
from tacotron.synthesizer import Synthesizer
from wsgiref import simple_server
from pypinyin import pinyin, lazy_pinyin, Style
# Markup for the single-page demo served at "/": a text box plus a button that
# fetches /synthesize?text=... and plays the returned audio in an <audio> tag.
# UI strings are Chinese: "请输入文字" = "please enter text", "合成" =
# "synthesize", "合成中..." = "synthesizing...", "出错: " = "error: ".
html_body = '''<html><title>Tacotron-2 Demo</title><meta charset='utf-8'>
<style>
body {padding: 16px; font-family: sans-serif; font-size: 14px; color: #444}
input {font-size: 14px; padding: 8px 12px; outline: none; border: 1px solid #ddd}
input:focus {box-shadow: 0 1px 2px rgba(0,0,0,.15)}
p {padding: 12px}
button {background: #28d; padding: 9px 14px; margin-left: 8px; border: none; outline: none;
        color: #fff; font-size: 14px; border-radius: 4px; cursor: pointer;}
button:hover {box-shadow: 0 1px 2px rgba(0,0,0,.15); opacity: 0.9;}
button:active {background: #29f;}
button[disabled] {opacity: 0.4; cursor: default}
</style>
<body>
<form>
  <input id="text" type="text" size="40" placeholder="请输入文字">
  <button id="button" name="synthesize">合成</button>
</form>
<p id="message"></p>
<audio id="audio" controls autoplay hidden></audio>
<script>
function q(selector) {return document.querySelector(selector)}
q('#text').focus()
q('#button').addEventListener('click', function(e) {
  var text = q('#text').value.trim()
  if (text) {
    q('#message').textContent = '合成中...'
    q('#button').disabled = true
    q('#audio').hidden = true
    synthesize(text)
  }
  e.preventDefault()
  return false
})
function synthesize(text) {
  fetch('/synthesize?text=' + encodeURIComponent(text), {cache: 'no-cache'})
    .then(function(res) {
      if (!res.ok) throw Error(res.statusText)
      return res.blob()
    }).then(function(blob) {
      q('#message').textContent = ''
      q('#button').disabled = false
      var audio = q('#audio')
      // Release the previous clip's object URL so blobs do not pile up.
      if (audio.src) URL.revokeObjectURL(audio.src)
      audio.src = URL.createObjectURL(blob)
      audio.hidden = false
    }).catch(function(err) {
      q('#message').textContent = '出错: ' + err.message
      q('#button').disabled = false
    })
}
</script></body></html>
'''
# Command-line options for the demo server.
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', default='pretrained/',
                    help='Checkpoint directory suffix; the model is loaded from logs-Tacotron/taco_<checkpoint>')
parser.add_argument('--hparams', default='', help='Hyperparameter overrides as a comma-separated list of name=value pairs')
parser.add_argument('--port', default=6006, type=int, help='Port of Http service')
parser.add_argument('--host', default="localhost", help='Host of Http service')
parser.add_argument('--name', help='Name of logging directory if the two models were trained together.')
args = parser.parse_args()

# NOTE(review): tensorflow is imported at the top of the file, before this
# assignment runs; if TF reads TF_CPP_MIN_LOG_LEVEL at import time this has
# no effect — consider exporting it in the environment before launch.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Resolve and validate the checkpoint before building the synthesizer.
checkpoint = os.path.join('logs-Tacotron', 'taco_' + args.checkpoint)
try:
    checkpoint_state = tf.train.get_checkpoint_state(checkpoint)
except Exception as e:
    raise RuntimeError('Failed to load checkpoint at {}'.format(checkpoint)) from e
# get_checkpoint_state returns None when the directory has no checkpoint file.
if checkpoint_state is None or not checkpoint_state.model_checkpoint_path:
    raise RuntimeError('Failed to load checkpoint at {}'.format(checkpoint))
checkpoint_path = checkpoint_state.model_checkpoint_path
log('loaded model at {}'.format(checkpoint_path))

# Build the module-level synthesizer used by the /synthesize handler.
synth = Synthesizer()
modified_hp = hparams.parse(args.hparams)
synth.load(checkpoint_path, modified_hp)
class Res:
    """Falcon resource that serves the demo HTML page (html_body)."""

    def on_get(self, req, res):
        # GET responder: return the static page markup as text/html.
        res.body, res.content_type = html_body, "text/html"
class Syn:
    """Falcon resource that synthesizes speech for the ``text`` query parameter."""

    def on_get(self, req, res):
        """Handle GET /synthesize?text=... and respond with WAV audio.

        The text is normalized by chs_norm, converted to pinyin sentences by
        chs_pinyin, synthesized with the module-level ``synth``, and saved as
        WAV bytes into the response body.

        Raises:
            falcon.HTTPBadRequest: if ``text`` is missing or blank, or if no
                pinyin sentences remain after punctuation filtering.
        """
        text = req.params.get('text')
        if isinstance(text, list):
            # NOTE(review): Falcon may return a list for repeated ``text=``
            # params, or when it splits values on commas (auto_parse_qs_csv);
            # re-join so commas in the input survive — confirm for the
            # installed Falcon version.
            text = ','.join(text)
        if not text or not text.strip():
            raise falcon.HTTPBadRequest()
        norm_chs = chs_norm(text)
        print(norm_chs)
        pys = chs_pinyin(norm_chs)
        if not pys:
            # Input consisted only of characters chs_pinyin drops (quotes,
            # brackets, spaces, ...): nothing to synthesize.
            raise falcon.HTTPBadRequest()
        wav = synth.eval(pys)
        out = io.BytesIO()
        audio.save_wav(wav, out, hparams)
        res.data = out.getvalue()
        res.content_type = "audio/wav"
# Punctuation handling for chs_pinyin. Each entry pairs a set of characters
# with the ASCII token that replaces them; None means the token is dropped.
# Full-width characters are written as \u escapes so they survive Unicode
# normalization (NFKC folds U+FF02 to '"', which would terminate the literal).
_PUNCT_GROUPS = (
    ("\uff0c\u3001\u00b7,", ","),                 # ， 、 · ,
    (".\u3002\u2026\uff0e", "."),                 # . 。 … ．
    ("\u2015\u2014", ","),                        # ― —
    ("\uff1b\uff1a:;", ","),                      # ； ： : ;
    ("\uff1f?", "?"),                             # ？ ?
    ("\uff01!", "!"),                             # ！ !
    ("\u300a\u300b\uff08\uff09()", None),         # 《 》 （ ） ( )
    ("\u201c\u201d\u2018\u2019\uff02\"'", None),  # “ ” ‘ ’ ＂ " '
    (" \u3000/<>\u300c\u300d", None),             # space, ideographic space, / < > 「 」
)
_PUNCT_MAP = {ch: repl for chars, repl in _PUNCT_GROUPS for ch in chars}
# Tokens after which the current sentence is closed.
_SENTENCE_BREAKS = frozenset(",.;?!:")


def chs_pinyin(text):
    """Convert Chinese text to a list of pinyin sentences for the synthesizer.

    Characters are transcribed with pypinyin's TONE3 style (tone number
    appended to the syllable). A token whose first character appears in
    _PUNCT_MAP is replaced as a whole by the mapped ASCII punctuation, or
    dropped when the mapping is None; other tokens are kept verbatim. Tokens
    are joined with single spaces, and a new sentence starts after every
    punctuation token. Each finished sentence is printed to stdout.

    Args:
        text: Text to convert, normally the output of chs_norm.

    Returns:
        A list of sentence strings, each ending with "," or "."; a sentence
        ending in anything else (including "?" or "!") gets " ." appended.
        Empty when every token was dropped.
    """
    results = []
    sentence = []
    for candidates in pinyin(text, style=Style.TONE3):
        token = candidates[0] if candidates else ''
        if not token:
            continue  # nothing to voice; also avoids indexing an empty string
        if token[0] in _PUNCT_MAP:
            token = _PUNCT_MAP[token[0]]
            if token is None:
                continue
        sentence.append(token)
        if token in _SENTENCE_BREAKS:
            results.append(' '.join(sentence))
            sentence = []
    if sentence:
        results.append(' '.join(sentence))
    for i, sent in enumerate(results):
        # NOTE(review): "?" and "!" endings also receive " ." — behavior kept;
        # confirm this matches the model's training transcripts.
        if sent[-1] not in ',.':
            results[i] = sent + ' .'
        print(results[i])  # the sentence exactly as it will be synthesized
    return results
def chs_norm(text, url='http://search.ximalaya.com/text-format/numberFormat/convert', timeout=10):
    """Normalize text through a remote text-format conversion service.

    The endpoint name (numberFormat/convert) suggests it rewrites numbers as
    Chinese words — TODO confirm against the service. The text is POSTed as a
    JSON array of single characters, and the JSON array of strings in the
    reply is concatenated into the result.

    Args:
        text: String to normalize. An empty string is returned as "" without
            contacting the service.
        url: Conversion endpoint. NOTE(review): plain HTTP to a third-party
            host, so user input leaves the machine unencrypted.
        timeout: Seconds to wait for the service. Without a timeout a stalled
            service would block the single-threaded wsgiref demo server.

    Returns:
        The normalized string.

    Raises:
        OSError: if the service is unreachable or times out (includes
            urllib.error.URLError and socket timeouts).
        ValueError: if ``url`` is malformed or the reply is not valid JSON.
    """
    # `import urllib` alone does not load the `request` submodule.
    import urllib.request

    if not text:
        return ''
    payload = json.dumps(list(text)).encode()
    request = urllib.request.Request(url, payload)
    request.add_header("Content-Type", 'application/json')
    with urllib.request.urlopen(request, timeout=timeout) as response:
        body = response.read()
    return ''.join(json.loads(body.decode()))
# Register the page and synthesis resources, then block serving requests
# with the standard-library WSGI reference server.
api = falcon.API()
for route, resource in (("/", Res()), ("/synthesize", Syn())):
    api.add_route(route, resource)
port = int(args.port)
log("host:{},port:{}".format(args.host, port))
server = simple_server.make_server(args.host, port, api)
server.serve_forever()