Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
NaruseMioShirakana authored Aug 11, 2022
1 parent 74b65b8 commit 6ad7786
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions SubProject VITSLJS-Libtorch/VITS-LibTorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ namespace Shirakana {

int main(int argc,char* argv[])
{
if (argc != 5) {
if (argc != 6 && argc != 7) {
return 0;
}
wchar_t* buffer;
Expand All @@ -76,18 +76,16 @@ int main(int argc,char* argv[])
std::wstring Path = bufferStr + L"\\" + PathTmp;
std::wstring TextInput = string2wstring(argv[2]);
std::wstring OutDir = string2wstring(argv[3]);
std::wstring HifiganPath = bufferStr + L"\\hifigan.onnx";
std::wstring SymbolStr = string2wstring(argv[4]);
std::string Mode = argv[5];
std::map<wchar_t, int64> Symbol;
for (size_t i = 0; i < SymbolStr.length(); i++) {
Symbol.insert(std::pair<wchar_t, int64>(SymbolStr[i], (int64)(i)));
}
std::vector<int64> text;
for (size_t i = 0; i < TextInput.length(); i++) {
text.push_back(0);
std::cout << text[text.size() - 1] << " ";
text.push_back(Symbol[TextInput[i]]);
std::cout << text[text.size() - 1] << " ";
}
text.push_back(0);
auto InputTensor = torch::from_blob(text.data(), { 1, (long long)text.size() }, torch::kInt64);
Expand All @@ -97,12 +95,31 @@ int main(int argc,char* argv[])
std::vector<torch::IValue> inputs;
inputs.push_back(InputTensor);
inputs.push_back(InputTensor_length);
try {
VITSMODULE = torch::jit::load(Shirakana::to_byte_string(PathTmp) + "_LJS.pt");
if (Mode == "LJS") {
try {
VITSMODULE = torch::jit::load(Shirakana::to_byte_string(PathTmp) + "_LJS.pt");
}
catch (c10::Error e) {
std::cout << e.what();
return 0;
}
}
catch (c10::Error e) {
std::cout << e.what() << "\ncheckpoint1";
return 0;
else if(Mode == "VCTK") {
if (argc != 7) {
return 0;
}
try {
std::array<int64, 1> speakerIndex{ (int64)atoi(argv[6]) };
VITSMODULE = torch::jit::load(Shirakana::to_byte_string(PathTmp) + "_VCTK.pt");
inputs.push_back(torch::from_blob(speakerIndex.data(), { 1 }, torch::kLong));
}
catch (c10::Error e) {
std::cout << e.what();
return 0;
}
}
else {
std::cout << "模式错误";
}
try {
auto output = VITSMODULE.forward(inputs).toTuple()->elements()[0].toTensor().multiply(32276.0F);
Expand All @@ -119,7 +136,7 @@ int main(int argc,char* argv[])
return Shirakana::conArr2Wav(outputSize, outputTmp, filenames.c_str());
}
catch (c10::Error e) {
std::cout << e.what() << "\ncheckpoint2";
std::cout << e.what();
return 0;
}
}

0 comments on commit 6ad7786

Please sign in to comment.