diff --git a/SubProject VITSLJS-Libtorch/VITS-LibTorch.cpp b/SubProject VITSLJS-Libtorch/VITS-LibTorch.cpp index 6ee757f..2ce8c4b 100644 --- a/SubProject VITSLJS-Libtorch/VITS-LibTorch.cpp +++ b/SubProject VITSLJS-Libtorch/VITS-LibTorch.cpp @@ -63,7 +63,7 @@ namespace Shirakana { int main(int argc,char* argv[]) { - if (argc != 5) { + if (argc != 6 && argc != 7) { return 0; } wchar_t* buffer; @@ -76,8 +76,8 @@ int main(int argc,char* argv[]) std::wstring Path = bufferStr + L"\\" + PathTmp; std::wstring TextInput = string2wstring(argv[2]); std::wstring OutDir = string2wstring(argv[3]); - std::wstring HifiganPath = bufferStr + L"\\hifigan.onnx"; std::wstring SymbolStr = string2wstring(argv[4]); + std::string Mode = argv[5]; std::map Symbol; for (size_t i = 0; i < SymbolStr.length(); i++) { Symbol.insert(std::pair(SymbolStr[i], (int64)(i))); @@ -85,9 +85,7 @@ int main(int argc,char* argv[]) std::vector text; for (size_t i = 0; i < TextInput.length(); i++) { text.push_back(0); - std::cout << text[text.size() - 1] << " "; text.push_back(Symbol[TextInput[i]]); - std::cout << text[text.size() - 1] << " "; } text.push_back(0); auto InputTensor = torch::from_blob(text.data(), { 1, (long long)text.size() }, torch::kInt64); @@ -97,12 +95,31 @@ int main(int argc,char* argv[]) std::vector inputs; inputs.push_back(InputTensor); inputs.push_back(InputTensor_length); - try { - VITSMODULE = torch::jit::load(Shirakana::to_byte_string(PathTmp) + "_LJS.pt"); + if (Mode == "LJS") { + try { + VITSMODULE = torch::jit::load(Shirakana::to_byte_string(PathTmp) + "_LJS.pt"); + } + catch (c10::Error e) { + std::cout << e.what(); + return 0; + } } - catch (c10::Error e) { - std::cout << e.what() << "\ncheckpoint1"; - return 0; + else if(Mode == "VCTK") { + if (argc != 7) { + return 0; + } + try { + std::array speakerIndex{ (int64)atoi(argv[6]) }; + VITSMODULE = torch::jit::load(Shirakana::to_byte_string(PathTmp) + "_VCTK.pt"); + inputs.push_back(torch::from_blob(speakerIndex.data(), { 1 }, torch::kLong)); + } + catch (c10::Error e) { + std::cout << e.what(); + return 0; + } + } + else { + std::cout << "模式错误"; } try { auto output = VITSMODULE.forward(inputs).toTuple()->elements()[0].toTensor().multiply(32276.0F); @@ -119,7 +136,7 @@ int main(int argc,char* argv[]) return Shirakana::conArr2Wav(outputSize, outputTmp, filenames.c_str()); } catch (c10::Error e) { - std::cout << e.what() << "\ncheckpoint2"; + std::cout << e.what(); return 0; } }