infer.py
import os
import json
import time
import argparse
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from modules.FastSVC import SVCNN
from modules.wavlm_encoder import WavLMEncoder
from utils.pitch_ld_extraction import extract_loudness, extract_pitch_ref as extract_pitch
from utils.tools import ConfigWrapper, fast_cosine_dist

# The 6th-layer WavLM features are used for audio synthesis.
SPEAKER_INFORMATION_LAYER = 6
# The mean of the WavLM features from the last 5 layers is used for kNN matching.
CONTENT_INFORMATION_LAYER = [20, 21, 22, 23, 24]


def VoiceConverter(test_utt: str, ref_utt: str, out_path: str, svc_mdl: SVCNN, wavlm_encoder: WavLMEncoder, f0_factor: float, speech_enroll=False, device=torch.device('cpu')):
    """
    Perform singing voice conversion and save the resulting waveform to `out_path`.

    Args:
        test_utt (str): Path to the source singing waveform (24kHz, single-channel).
        ref_utt (str): Path to the reference waveform from the target speaker (single-channel, at least 16kHz).
        out_path (str): Path to save the converted singing audio.
        svc_mdl (SVCNN): Loaded FastSVC model with neural harmonic filters.
        wavlm_encoder (WavLMEncoder): Loaded WavLM encoder.
        f0_factor (float): F0 shift factor.
        speech_enroll (bool, optional): Whether the reference audio is a speech clip rather than a singing clip. Defaults to False.
        device (torch.device, optional): Device to perform the conversion on. Defaults to CPU.
    """
    # Preprocess audio and extract features.
    print('Processing feats.')
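    # Build a (25, 1) layer-weight vector: 1/5 at indices 20-24 and 0 elsewhere, so the WavLM
    # features used for kNN matching are the mean of the CONTENT_INFORMATION_LAYER outputs.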
    applied_weights = F.one_hot(torch.tensor(CONTENT_INFORMATION_LAYER), num_classes=25).float().mean(axis=0).to(device)[:, None]
    ld = extract_loudness(test_utt)
    pitch, f0_factor = extract_pitch(test_utt, ref_utt, predefined_factor=f0_factor, speech_enroll=speech_enroll)
    assert pitch.shape[0] == ld.shape[0], f'{test_utt} Length Mismatch: pitch length ({pitch.shape[0]}), ld length ({ld.shape[0]}).'
    query_feats = wavlm_encoder.get_features(test_utt, weights=applied_weights)
    matching_set = wavlm_encoder.get_features(ref_utt, weights=applied_weights)
    synth_set = wavlm_encoder.get_features(ref_utt, output_layer=SPEAKER_INFORMATION_LAYER)

    # Calculate the cosine distance between the query feats and the matching feats.
    dists = fast_cosine_dist(query_feats, matching_set, device=device)
    best = dists.topk(k=4, largest=False, dim=-1)
    # Replace each query frame with the mean of its 4 nearest neighbours from the synth set.
    prematched_wavlm = synth_set[best.indices].mean(dim=1).transpose(0, 1)  # (1024, T)
    # Align the features: the hop size of the WavLM features is twice that of pitch and loudness.
    seq_len = prematched_wavlm.shape[1] * 2
    if seq_len > pitch.shape[0]:
        p = seq_len - pitch.shape[0]
        pitch = np.pad(pitch, (0, p), mode='edge')
        ld = np.pad(ld, (0, p), mode='edge')
    else:
        pitch = pitch[:seq_len]
        ld = ld[:seq_len]
    in_feats = [prematched_wavlm.unsqueeze(0), torch.from_numpy(pitch).to(dtype=torch.float).unsqueeze(0),
                torch.from_numpy(ld).to(dtype=torch.float).unsqueeze(0)]
    in_feats = tuple([x_.to(device) for x_ in in_feats])
    # Inference
    print('Inferencing.')
    with torch.no_grad():
        y_ = svc_mdl(*in_feats)

    # Save converted audio.
    print('Saving audio.')
    os.makedirs(os.path.dirname(os.path.abspath(out_path)), exist_ok=True)
    y_ = y_.unsqueeze(0)
    y_ = np.clip(y_.view(-1).cpu().numpy(), -1, 1)
    sf.write(out_path, y_, 24000)
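
# Minimal programmatic usage (a sketch mirroring the CLI entry point below; the checkpoint
# and audio paths are placeholders):
#
#     config = ConfigWrapper(**json.load(open('configs/config.json')))
#     svc_mdl = SVCNN(config)
#     svc_mdl.load_state_dict(torch.load('pretrained/model.pkl', map_location='cpu')['model']['generator'], strict=False)
#     svc_mdl.eval()
#     wavlm_encoder = WavLMEncoder(ckpt_path='pretrained/WavLM-Large.pt', device=torch.device('cpu'))
#     VoiceConverter('source.wav', 'reference.wav', 'converted.wav', svc_mdl, wavlm_encoder, f0_factor=0.0)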


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--src_wav_path',
        required=True, type=str, help='The audio path for the source singing utterance.'
    )
    parser.add_argument(
        '--ref_wav_path',
        required=True, type=str, help='The audio path for the reference utterance.'
    )
    parser.add_argument(
        '--out_path',
        required=True, type=str, help='The path to save the converted singing audio.'
    )
    parser.add_argument(
        '-cfg', '--config_file',
        type=str, default='configs/config.json',
        help='The model configuration file.'
    )
    parser.add_argument(
        '-ckpt', '--ckpt_path',
        type=str,
        default='pretrained/model.pkl',
        help='The model checkpoint path for loading.'
    )
    parser.add_argument(
        '-f0factor', '--f0_factor', type=float, default=0.0,
        help='Adjust the pitch of the source singing to match the vocal range of the target singer. \
              The default value is 0.0, which means no pitch adjustment is applied (equivalent to f0_factor = 1.0).'
    )
    parser.add_argument(
        '--speech_enroll', action='store_true',
        help='When using speech as the reference audio, the pitch of the reference audio will be scaled by a factor of 1.2 \
              when performing pitch shift, to cover the pitch gap between singing and speech. \
              Note: This option is invalid when f0_factor is specified.'
    )
    args = parser.parse_args()
    t0 = time.time()
    f0factor = args.f0_factor
    speech_enroll_flag = args.speech_enroll

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(f'using {device} for inference.')

    # Loading model and parameters.
    # load svc model
    cfg = args.config_file
    model_path = args.ckpt_path
    print('Loading svc model configurations.')
    with open(cfg) as f:
        config = ConfigWrapper(**json.load(f))
    svc_mdl = SVCNN(config)
    state_dict = torch.load(model_path, map_location='cpu')
    svc_mdl.load_state_dict(state_dict['model']['generator'], strict=False)
    svc_mdl.to(device)
    svc_mdl.eval()

    # load wavlm model
    wavlm_encoder = WavLMEncoder(ckpt_path='pretrained/WavLM-Large.pt', device=device)
    print('wavlm loaded.')
    # End loading model and parameters.
    t1 = time.time()
    print(f'loading models cost {t1-t0:.2f}s.')

    VoiceConverter(test_utt=args.src_wav_path, ref_utt=args.ref_wav_path, out_path=args.out_path,
                   svc_mdl=svc_mdl, wavlm_encoder=wavlm_encoder,
                   f0_factor=f0factor, speech_enroll=speech_enroll_flag, device=device)
    t2 = time.time()
    print(f'converting costs {t2-t1:.2f}s.')
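
# Example invocation (the audio paths below are placeholders):
#   python infer.py --src_wav_path examples/source_singing.wav \
#       --ref_wav_path examples/target_reference.wav \
#       --out_path outputs/converted.wav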