-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: whisper addon support more configuration options
- Loading branch information
Showing
2 changed files
with
345 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,344 @@ | ||
#include "napi.h" | ||
#include "common.h" | ||
|
||
#include "whisper.h" | ||
|
||
#include <string> | ||
#include <thread> | ||
#include <vector> | ||
#include <cmath> | ||
#include <cstdint> | ||
|
||
struct whisper_params { | ||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); | ||
int32_t n_processors = 1; | ||
int32_t offset_t_ms = 0; | ||
int32_t offset_n = 0; | ||
int32_t duration_ms = 0; | ||
int32_t max_context = -1; | ||
int32_t max_len = 0; | ||
int32_t best_of = 5; | ||
int32_t beam_size = -1; | ||
|
||
float word_thold = 0.01f; | ||
float entropy_thold = 2.4f; | ||
float logprob_thold = -1.0f; | ||
|
||
bool speed_up = false; | ||
bool translate = false; | ||
bool diarize = false; | ||
bool output_txt = false; | ||
bool output_vtt = false; | ||
bool output_srt = false; | ||
bool output_wts = false; | ||
bool output_csv = false; | ||
bool print_special = false; | ||
bool print_colors = false; | ||
bool print_progress = false; | ||
bool no_timestamps = false; | ||
|
||
std::string language = "en"; | ||
std::string prompt; | ||
std::string model = "../../ggml-large.bin"; | ||
|
||
std::vector<std::string> fname_inp = {}; | ||
std::vector<std::string> fname_out = {}; | ||
}; | ||
|
||
struct whisper_print_user_data { | ||
const whisper_params * params; | ||
|
||
const std::vector<std::vector<float>> * pcmf32s; | ||
}; | ||
|
||
// 500 -> 00:05.000 | ||
// 6000 -> 01:00.000 | ||
std::string to_timestamp(int64_t t, bool comma = false) { | ||
int64_t msec = t * 10; | ||
int64_t hr = msec / (1000 * 60 * 60); | ||
msec = msec - hr * (1000 * 60 * 60); | ||
int64_t min = msec / (1000 * 60); | ||
msec = msec - min * (1000 * 60); | ||
int64_t sec = msec / 1000; | ||
msec = msec - sec * 1000; | ||
|
||
char buf[32]; | ||
snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec); | ||
|
||
return std::string(buf); | ||
} | ||
|
||
int timestamp_to_sample(int64_t t, int n_samples) { | ||
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); | ||
} | ||
|
||
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { | ||
const auto & params = *((whisper_print_user_data *) user_data)->params; | ||
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; | ||
|
||
const int n_segments = whisper_full_n_segments(ctx); | ||
|
||
std::string speaker = ""; | ||
|
||
int64_t t0; | ||
int64_t t1; | ||
|
||
// print the last n_new segments | ||
const int s0 = n_segments - n_new; | ||
|
||
if (s0 == 0) { | ||
printf("\n"); | ||
} | ||
|
||
for (int i = s0; i < n_segments; i++) { | ||
if (!params.no_timestamps || params.diarize) { | ||
t0 = whisper_full_get_segment_t0(ctx, i); | ||
t1 = whisper_full_get_segment_t1(ctx, i); | ||
} | ||
|
||
if (!params.no_timestamps) { | ||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); | ||
} | ||
|
||
if (params.diarize && pcmf32s.size() == 2) { | ||
const int64_t n_samples = pcmf32s[0].size(); | ||
|
||
const int64_t is0 = timestamp_to_sample(t0, n_samples); | ||
const int64_t is1 = timestamp_to_sample(t1, n_samples); | ||
|
||
double energy0 = 0.0f; | ||
double energy1 = 0.0f; | ||
|
||
for (int64_t j = is0; j < is1; j++) { | ||
energy0 += fabs(pcmf32s[0][j]); | ||
energy1 += fabs(pcmf32s[1][j]); | ||
} | ||
|
||
if (energy0 > 1.1*energy1) { | ||
speaker = "(speaker 0)"; | ||
} else if (energy1 > 1.1*energy0) { | ||
speaker = "(speaker 1)"; | ||
} else { | ||
speaker = "(speaker ?)"; | ||
} | ||
|
||
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str()); | ||
} | ||
|
||
// colorful print bug | ||
// | ||
const char * text = whisper_full_get_segment_text(ctx, i); | ||
printf("%s%s", speaker.c_str(), text); | ||
|
||
|
||
// with timestamps or speakers: each segment on new line | ||
if (!params.no_timestamps || params.diarize) { | ||
printf("\n"); | ||
} | ||
|
||
fflush(stdout); | ||
} | ||
} | ||
|
||
int run(whisper_params ¶ms, std::vector<std::vector<std::string>> &result) { | ||
if (params.fname_inp.empty()) { | ||
fprintf(stderr, "error: no input files specified\n"); | ||
return 2; | ||
} | ||
|
||
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) { | ||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str()); | ||
exit(0); | ||
} | ||
|
||
// whisper init | ||
|
||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); | ||
|
||
if (ctx == nullptr) { | ||
fprintf(stderr, "error: failed to initialize whisper context\n"); | ||
return 3; | ||
} | ||
|
||
for (int f = 0; f < (int) params.fname_inp.size(); ++f) { | ||
const auto fname_inp = params.fname_inp[f]; | ||
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f]; | ||
|
||
std::vector<float> pcmf32; // mono-channel F32 PCM | ||
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM | ||
|
||
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) { | ||
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str()); | ||
continue; | ||
} | ||
|
||
// print system information | ||
{ | ||
fprintf(stderr, "\n"); | ||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", | ||
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info()); | ||
} | ||
|
||
// print some info about the processing | ||
{ | ||
fprintf(stderr, "\n"); | ||
if (!whisper_is_multilingual(ctx)) { | ||
if (params.language != "en" || params.translate) { | ||
params.language = "en"; | ||
params.translate = false; | ||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); | ||
} | ||
} | ||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n", | ||
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, | ||
params.n_threads, params.n_processors, | ||
params.language.c_str(), | ||
params.translate ? "translate" : "transcribe", | ||
params.no_timestamps ? 0 : 1); | ||
|
||
fprintf(stderr, "\n"); | ||
} | ||
|
||
// run the inference | ||
{ | ||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); | ||
|
||
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; | ||
|
||
wparams.print_realtime = false; | ||
wparams.print_progress = params.print_progress; | ||
wparams.print_timestamps = !params.no_timestamps; | ||
wparams.print_special = params.print_special; | ||
wparams.translate = params.translate; | ||
wparams.language = params.language.c_str(); | ||
wparams.n_threads = params.n_threads; | ||
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; | ||
wparams.offset_ms = params.offset_t_ms; | ||
wparams.duration_ms = params.duration_ms; | ||
|
||
wparams.token_timestamps = params.output_wts || params.max_len > 0; | ||
wparams.thold_pt = params.word_thold; | ||
wparams.entropy_thold = params.entropy_thold; | ||
wparams.logprob_thold = params.logprob_thold; | ||
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; | ||
|
||
wparams.speed_up = params.speed_up; | ||
|
||
wparams.greedy.best_of = params.best_of; | ||
wparams.beam_search.beam_size = params.beam_size; | ||
|
||
wparams.initial_prompt = params.prompt.c_str(); | ||
|
||
whisper_print_user_data user_data = { ¶ms, &pcmf32s }; | ||
|
||
// this callback is called on each new segment | ||
if (!wparams.print_realtime) { | ||
wparams.new_segment_callback = whisper_print_segment_callback; | ||
wparams.new_segment_callback_user_data = &user_data; | ||
} | ||
|
||
// example for abort mechanism | ||
// in this example, we do not abort the processing, but we could if the flag is set to true | ||
// the callback is called before every encoder run - if it returns false, the processing is aborted | ||
{ | ||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race | ||
|
||
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { | ||
bool is_aborted = *(bool*)user_data; | ||
return !is_aborted; | ||
}; | ||
wparams.encoder_begin_callback_user_data = &is_aborted; | ||
} | ||
|
||
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { | ||
fprintf(stderr, "failed to process audio\n"); | ||
return 10; | ||
} | ||
} | ||
} | ||
|
||
const int n_segments = whisper_full_n_segments(ctx); | ||
result.resize(n_segments); | ||
for (int i = 0; i < n_segments; ++i) { | ||
const char * text = whisper_full_get_segment_text(ctx, i); | ||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i); | ||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i); | ||
|
||
result[i].emplace_back(to_timestamp(t0, true)); | ||
result[i].emplace_back(to_timestamp(t1, true)); | ||
result[i].emplace_back(text); | ||
} | ||
|
||
whisper_print_timings(ctx); | ||
whisper_free(ctx); | ||
|
||
return 0; | ||
} | ||
|
||
class Worker : public Napi::AsyncWorker { | ||
public: | ||
Worker(Napi::Function& callback, whisper_params params) | ||
: Napi::AsyncWorker(callback), params(params) {} | ||
|
||
void Execute() override { | ||
run(params, result); | ||
} | ||
|
||
void OnOK() override { | ||
Napi::HandleScope scope(Env()); | ||
Napi::Object res = Napi::Array::New(Env(), result.size()); | ||
for (uint64_t i = 0; i < result.size(); ++i) { | ||
Napi::Object tmp = Napi::Array::New(Env(), 3); | ||
for (uint64_t j = 0; j < 3; ++j) { | ||
tmp[j] = Napi::String::New(Env(), result[i][j]); | ||
} | ||
res[i] = tmp; | ||
} | ||
Callback().Call({Env().Null(), res}); | ||
} | ||
|
||
private: | ||
whisper_params params; | ||
std::vector<std::vector<std::string>> result; | ||
}; | ||
|
||
|
||
|
||
Napi::Value whisper(const Napi::CallbackInfo& info) { | ||
Napi::Env env = info.Env(); | ||
if (info.Length() <= 0 || !info[0].IsObject()) { | ||
Napi::TypeError::New(env, "object expected").ThrowAsJavaScriptException(); | ||
} | ||
whisper_params params; | ||
|
||
Napi::Object whisper_params = info[0].As<Napi::Object>(); | ||
std::string language = whisper_params.Get("language").As<Napi::String>(); | ||
std::string model = whisper_params.Get("model").As<Napi::String>(); | ||
std::string input = whisper_params.Get("fname_inp").As<Napi::String>(); | ||
int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>(); | ||
bool translate = whisper_params.Get("translate").As<Napi::Boolean>(); | ||
bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>(); | ||
|
||
params.language = language; | ||
params.model = model; | ||
params.fname_inp.emplace_back(input); | ||
params.max_len = max_len; | ||
params.translate = translate; | ||
params.no_timestamps = no_timestamps; | ||
|
||
Napi::Function callback = info[1].As<Napi::Function>(); | ||
Worker* worker = new Worker(callback, params); | ||
worker->Queue(); | ||
return env.Undefined(); | ||
} | ||
|
||
|
||
Napi::Object Init(Napi::Env env, Napi::Object exports) { | ||
exports.Set( | ||
Napi::String::New(env, "whisper"), | ||
Napi::Function::New(env, whisper) | ||
); | ||
return exports; | ||
} | ||
|
||
NODE_API_MODULE(whisper, Init); |