-
Notifications
You must be signed in to change notification settings - Fork 145
/
infer_path.py
65 lines (54 loc) · 3.36 KB
/
infer_path.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import argparse
import functools
import time
from data_utils.audio_process import AudioInferProcess
from utils.predict import Predictor
from utils.audio_vad import crop_audio_vad
from utils.utility import add_arguments, print_arguments
# Command-line interface. Every option is registered through the project's
# ``add_arguments`` helper, pre-bound to this script's parser.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)

# One row per option, in registration order:
# (name, type, default, help text, extra keyword arguments).
_OPTION_SPECS = (
    # Input audio and general runtime switches.
    ('wav_path', str, './dataset/test.wav', "预测音频的路径", {}),
    ('is_long_audio', bool, False, "是否为长语音", {}),
    ('use_gpu', bool, True, "是否使用GPU预测", {}),
    ('enable_mkldnn', bool, False, "是否使用mkldnn加速", {}),
    ('to_an', bool, True, "是否转为阿拉伯数字", {}),
    # Beam-search decoding parameters.
    ('beam_size', int, 300, "集束搜索解码相关参数,搜索的大小,范围:[5, 500]", {}),
    ('alpha', float, 1.2, "集束搜索解码相关参数,LM系数", {}),
    ('beta', float, 0.35, "集束搜索解码相关参数,WC系数", {}),
    ('cutoff_prob', float, 0.99, "集束搜索解码相关参数,剪枝的概率", {}),
    ('cutoff_top_n', int, 40, "集束搜索解码相关参数,剪枝的最大值", {}),
    # Paths to dataset statistics, vocabulary, model and language model.
    ('mean_std_path', str, './dataset/mean_std.npz', "数据集的均值和标准值的npy文件路径", {}),
    ('vocab_path', str, './dataset/zh_vocab.txt', "数据集的词汇表文件路径", {}),
    ('model_dir', str, './models/infer/', "导出的预测模型文件夹路径", {}),
    ('lang_model_path', str, './lm/zh_giga.no_cna_cmn.prune01244.klm', "集束搜索解码相关参数,语言模型文件路径", {}),
    # Decoding strategy, restricted to the two supported choices.
    ('decoding_method', str, 'ctc_greedy', "结果解码方法,有集束搜索(ctc_beam_search)、贪婪策略(ctc_greedy)",
     {'choices': ['ctc_beam_search', 'ctc_greedy']}),
)
for _opt_name, _opt_type, _opt_default, _opt_help, _opt_extra in _OPTION_SPECS:
    add_arg(_opt_name, _opt_type, _opt_default, _opt_help, **_opt_extra)

# Parse the command line once at import time and echo the effective settings.
args = parser.parse_args()
print_arguments(args)
# Build the data processor; per the original note it is needed both for
# processing the audio data and for obtaining the vocabulary.
audio_process = AudioInferProcess(
    vocab_filepath=args.vocab_path,
    mean_std_filepath=args.mean_std_path,
)

# Create the predictor from the parsed command-line options.
predictor = Predictor(
    model_dir=args.model_dir,
    audio_process=audio_process,
    decoding_method=args.decoding_method,
    alpha=args.alpha,
    beta=args.beta,
    lang_model_path=args.lang_model_path,
    beam_size=args.beam_size,
    cutoff_prob=args.cutoff_prob,
    cutoff_top_n=args.cutoff_top_n,
    use_gpu=args.use_gpu,
    enable_mkldnn=args.enable_mkldnn,
)
def predict_long_audio():
    """Recognise a long recording segment by segment and print the results.

    The file at ``args.wav_path`` is split with ``crop_audio_vad``; every
    returned segment path is passed to ``predictor.predict``. One line is
    printed per segment, followed by a summary line containing the elapsed
    time in milliseconds, the mean segment score and the segment texts
    joined with ','.

    If the splitter returns no segments, the summary reports a score of 0
    and an empty text instead of raising ``ZeroDivisionError``.
    """
    start = time.time()
    # Split the long audio into shorter segments.
    audios_path = crop_audio_vad(args.wav_path)
    texts = []
    scores = []
    # Run recognition on each segment in order.
    for i, audio_path in enumerate(audios_path):
        score, text = predictor.predict(audio_path=audio_path, to_an=args.to_an)
        texts.append(text)
        scores.append(score)
        print("第%d个分割音频, 得分: %d, 识别结果: %s" % (i, score, text))
    elapsed_ms = round((time.time() - start) * 1000)
    # Guard the mean: an empty segment list must not divide by zero.
    mean_score = sum(scores) / len(scores) if scores else 0
    # Joining (rather than prefixing each piece with a separator) avoids the
    # stray leading ',' in the final text.
    print("最终结果,消耗时间:%dms, 得分: %d, 识别结果: %s" % (elapsed_ms, mean_score, ','.join(texts)))
def predict_audio():
    """Recognise the single audio file at ``args.wav_path`` and print the result.

    Prints the elapsed time in milliseconds, the recognised text and the
    score returned by ``predictor.predict``.
    """
    began = time.time()
    result = predictor.predict(audio_path=args.wav_path, to_an=args.to_an)
    score, text = result
    # Elapsed time is measured after prediction finishes, in milliseconds.
    elapsed_ms = round((time.time() - began) * 1000)
    print("消耗时间:%dms, 识别结果: %s, 得分: %d" % (elapsed_ms, text, score))
if __name__ == "__main__":
    # Choose the recognition routine according to args.is_long_audio.
    run = predict_long_audio if args.is_long_audio else predict_audio
    run()