Skip to content

Commit

Permalink
Merge pull request #6 from nrl-ai/hao_search_emb
Browse files Browse the repository at this point in the history
Embedding Search
  • Loading branch information
vietanhdev authored Aug 2, 2023
2 parents 9311d56 + b126f6f commit b9f14c9
Show file tree
Hide file tree
Showing 47 changed files with 3,055 additions and 316 deletions.
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ find_package(OpenCV REQUIRED)
add_subdirectory(libs/whisper-cpp)
add_subdirectory(libs/subprocess)

# Build embedding search module
add_subdirectory(customchar/embeddb)

# Build CustomChar-core
set(TARGET customchar-core)
add_library(
Expand All @@ -49,6 +52,9 @@ target_include_directories(
)
target_link_libraries(${TARGET} PUBLIC ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT} ${OpenCV_LIBS} whisper subprocess)

# Set binary output directory
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

# CustomChar - cli
add_executable(
customchar-cli
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ doxygen Doxyfile.in

We welcome all contributions to this project.

- For coding style, please follow the style of the existing code. We basically follow the [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
- For coding style, please follow the style of the existing code.
- Install [clang-format](https://clang.llvm.org/docs/ClangFormat.html) for auto formatting the code.
- Install [pre-commit](https://pre-commit.com/) for the auto-formatting hook or manually run the script `scripts/format-code.sh` to format the code.

Expand Down
17 changes: 9 additions & 8 deletions customchar/audio/audio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
using namespace CC;
using namespace CC::audio;

bool CC::audio::ReadWav(const std::string& fname, std::vector<float>& pcmf32,
std::vector<std::vector<float>>& pcmf32s, bool stereo) {
bool CC::audio::read_wav(const std::string& fname, std::vector<float>& pcmf32,
std::vector<std::vector<float>>& pcmf32s,
bool stereo) {
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin

Expand Down Expand Up @@ -114,8 +115,8 @@ bool CC::audio::ReadWav(const std::string& fname, std::vector<float>& pcmf32,
return true;
}

void CC::audio::HighPassFilter(std::vector<float>& data, float cutoff,
float sample_rate) {
void CC::audio::high_pass_filter(std::vector<float>& data, float cutoff,
float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
Expand All @@ -128,9 +129,9 @@ void CC::audio::HighPassFilter(std::vector<float>& data, float cutoff,
}
}

bool CC::audio::VADSimple(std::vector<float>& pcmf32, int sample_rate,
int last_ms, float vad_thold, float freq_thold,
bool verbose) {
bool CC::audio::vad_simple(std::vector<float>& pcmf32, int sample_rate,
int last_ms, float vad_thold, float freq_thold,
bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;

Expand All @@ -140,7 +141,7 @@ bool CC::audio::VADSimple(std::vector<float>& pcmf32, int sample_rate,
}

if (freq_thold > 0.0f) {
CC::audio::HighPassFilter(pcmf32, freq_thold, sample_rate);
CC::audio::high_pass_filter(pcmf32, freq_thold, sample_rate);
}

float energy_all = 0.0f;
Expand Down
11 changes: 6 additions & 5 deletions customchar/audio/audio.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,18 @@ namespace audio {
/// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
/// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain
/// 2 channel PCM
bool ReadWav(const std::string& fname, std::vector<float>& pcmf32,
std::vector<std::vector<float>>& pcmf32s, bool stereo);
bool read_wav(const std::string& fname, std::vector<float>& pcmf32,
std::vector<std::vector<float>>& pcmf32s, bool stereo);

/// @brief Apply a high-pass frequency filter to PCM audio
/// Suppresses frequencies below cutoff Hz
void HighPassFilter(std::vector<float>& data, float cutoff, float sample_rate);
void high_pass_filter(std::vector<float>& data, float cutoff,
float sample_rate);

/// @brief Basic voice activity detection (VAD) using an adaptive
/// audio-energy threshold
bool VADSimple(std::vector<float>& pcmf32, int sample_rate, int last_ms,
float vad_thold, float freq_thold, bool verbose);
bool vad_simple(std::vector<float>& pcmf32, int sample_rate, int last_ms,
float vad_thold, float freq_thold, bool verbose);

} // namespace audio
} // namespace CC
Expand Down
16 changes: 8 additions & 8 deletions customchar/audio/sdl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ AudioAsync::~AudioAsync() {
}
}

bool AudioAsync::Init(int capture_id, int sample_rate) {
bool AudioAsync::initialize(int capture_id, int sample_rate) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);

if (SDL_Init(SDL_INIT_AUDIO) < 0) {
Expand Down Expand Up @@ -47,7 +47,7 @@ bool AudioAsync::Init(int capture_id, int sample_rate) {
capture_spec_requested.callback = [](void* userdata, uint8_t* stream,
int len) {
AudioAsync* audio = (AudioAsync*)userdata;
audio->Callback(stream, len);
audio->callback(stream, len);
};
capture_spec_requested.userdata = this;

Expand Down Expand Up @@ -90,7 +90,7 @@ bool AudioAsync::Init(int capture_id, int sample_rate) {
return true;
}

bool AudioAsync::Resume() {
bool AudioAsync::resume() {
if (!m_dev_id_in_) {
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
return false;
Expand All @@ -106,7 +106,7 @@ bool AudioAsync::Resume() {
return true;
}

bool AudioAsync::Pause() {
bool AudioAsync::pause() {
if (!m_dev_id_in_) {
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
return false;
Expand All @@ -123,7 +123,7 @@ bool AudioAsync::Pause() {
return true;
}

bool AudioAsync::Clear() {
bool AudioAsync::clear() {
if (!m_dev_id_in_) {
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
return false;
Expand All @@ -144,7 +144,7 @@ bool AudioAsync::Clear() {
return true;
}

void AudioAsync::Callback(uint8_t* stream, int len) {
void AudioAsync::callback(uint8_t* stream, int len) {
if (!m_running_) {
return;
}
Expand Down Expand Up @@ -177,7 +177,7 @@ void AudioAsync::Callback(uint8_t* stream, int len) {
}
}

void AudioAsync::Get(int ms, std::vector<float>& result) {
void AudioAsync::get(int ms, std::vector<float>& result) {
if (!m_dev_id_in_) {
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
return;
Expand Down Expand Up @@ -220,7 +220,7 @@ void AudioAsync::Get(int ms, std::vector<float>& result) {
}
}

bool CC::audio::SDLPollEvents() {
bool CC::audio::sdl_poll_events() {
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
Expand Down
14 changes: 7 additions & 7 deletions customchar/audio/sdl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,23 @@ class AudioAsync {
AudioAsync(int len_ms);
~AudioAsync();

bool Init(int capture_id, int sample_rate);
bool initialize(int capture_id, int sample_rate);

/// Start capturing audio via the provided SDL callback and
/// keep the last len_ms milliseconds of audio in a circular buffer
bool Resume();
bool Pause();
bool Clear();
bool resume();
bool pause();
bool clear();

/// @brief Callback function for SDL
/// @param stream Audio stream
/// @param len Length of the stream
void Callback(uint8_t* stream, int len);
void callback(uint8_t* stream, int len);

/// @brief Get audio from the circular buffer
/// @param ms Number of milliseconds to get
/// @param audio Output audio
void Get(int ms, std::vector<float>& audio);
void get(int ms, std::vector<float>& audio);

private:
SDL_AudioDeviceID m_dev_id_in_ = 0;
Expand All @@ -52,7 +52,7 @@ class AudioAsync {
};

// Returns false if the application needs to quit
bool SDLPollEvents();
bool sdl_poll_events();

} // namespace audio
} // namespace CC
Expand Down
18 changes: 9 additions & 9 deletions customchar/audio/speech_recognizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,21 @@ SpeechRecognizer::SpeechRecognizer(const std::string& model_path_,
fprintf(stderr, "\n");
}

InitPrompt();
init_prompt();
}

SpeechRecognizer::~SpeechRecognizer() {
whisper_print_timings(context_);
whisper_free(context_);
}

void SpeechRecognizer::InitPrompt() {
void SpeechRecognizer::init_prompt() {
const std::string bot_name_ = "CustomChar";
prompt_ = common::Replace(k_prompt_whisper_, "{1}", bot_name_);
prompt_ = common::replace(k_prompt_whisper_, "{1}", bot_name_);
}

std::string SpeechRecognizer::PostProcess(const std::string& text_heard) {
std::string processed_text = common::Trim(text_heard);
std::string SpeechRecognizer::postprocess(const std::string& text_heard) {
std::string processed_text = common::trim(text_heard);

// Remove text between brackets using regex
{
Expand Down Expand Up @@ -91,15 +91,15 @@ std::string SpeechRecognizer::PostProcess(const std::string& text_heard) {
return processed_text;
}

std::string SpeechRecognizer::Recognize(const std::vector<float>& pcmf32,
std::string SpeechRecognizer::recognize(const std::vector<float>& pcmf32,
float& prob, int64_t& t_ms) {
std::string text_heard;
text_heard = Transcribe(pcmf32, prob, t_ms);
text_heard = PostProcess(text_heard);
text_heard = transcribe(pcmf32, prob, t_ms);
text_heard = postprocess(text_heard);
return text_heard;
}

std::string SpeechRecognizer::Transcribe(const std::vector<float>& pcmf32,
std::string SpeechRecognizer::transcribe(const std::vector<float>& pcmf32,
float& prob, int64_t& t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();

Expand Down
8 changes: 4 additions & 4 deletions customchar/audio/speech_recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class SpeechRecognizer {
std::string prompt_;

/// @brief Initialize prompt
void InitPrompt();
void init_prompt();

std::string model_path_;
std::string language;
Expand All @@ -40,10 +40,10 @@ class SpeechRecognizer {
bool speed_up;

/// @brief Postprocess text
std::string PostProcess(const std::string& text_heard);
std::string postprocess(const std::string& text_heard);

/// @brief Transcribe speech
std::string Transcribe(const std::vector<float>& pcmf32, float& prob,
std::string transcribe(const std::vector<float>& pcmf32, float& prob,
int64_t& t_ms);

public:
Expand All @@ -61,7 +61,7 @@ class SpeechRecognizer {
/// @param prob Output probability
/// @param t_ms Output time
/// @return Recognized text
std::string Recognize(const std::vector<float>& audio_buff, float& prob,
std::string recognize(const std::vector<float>& audio_buff, float& prob,
int64_t& t_ms);

}; // class SpeechRecognizer
Expand Down
28 changes: 14 additions & 14 deletions customchar/audio/voice_recorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,38 @@ using namespace CC::audio;
VoiceRecorder::VoiceRecorder() {
audio_ = new AudioAsync(30 * 1000);
int capture_id = 0; // TODO: Make this configurable
if (!audio_->Init(capture_id, WHISPER_SAMPLE_RATE)) {
if (!audio_->initialize(capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio_->init() failed!\n", __func__);
exit(1);
}

audio_->Resume();
audio_->resume();
}

void VoiceRecorder::ClearAudioBuffer() { audio_->Clear(); }
void VoiceRecorder::clear_audio_buffer() { audio_->clear(); }

void VoiceRecorder::SampleAudio() { audio_->Get(2000, pcmf32_cur_); }
void VoiceRecorder::sample_audio() { audio_->get(2000, pcmf32_cur_); }

bool VoiceRecorder::FinishedTalking() {
bool VoiceRecorder::finished_talking() {
float vad_thold = 0.6f;
float freq_thold = 100.0f;
bool print_energy = false;
return VADSimple(pcmf32_cur_, WHISPER_SAMPLE_RATE, 1250, vad_thold,
freq_thold, print_energy);
return vad_simple(pcmf32_cur_, WHISPER_SAMPLE_RATE, 1250, vad_thold,
freq_thold, print_energy);
}

void VoiceRecorder::GetAudio(std::vector<float>& result) {
void VoiceRecorder::get_audio(std::vector<float>& result) {
int32_t voice_ms = 10000;
audio_->Get(voice_ms, pcmf32_cur_);
audio_->get(voice_ms, pcmf32_cur_);
result = pcmf32_cur_;
}

std::vector<float> VoiceRecorder::RecordSpeech() {
std::vector<float> VoiceRecorder::record_speech() {
bool is_running;
std::vector<float> audio_buff;
while (true) {
// Handle Ctrl + C
is_running = audio::SDLPollEvents();
is_running = audio::sdl_poll_events();
if (!is_running) {
break;
}
Expand All @@ -47,13 +47,13 @@ std::vector<float> VoiceRecorder::RecordSpeech() {
std::this_thread::sleep_for(std::chrono::milliseconds(100));

// Sample audio
SampleAudio();
if (!FinishedTalking()) {
sample_audio();
if (!finished_talking()) {
continue;
}

// Get recorded audio
GetAudio(audio_buff);
get_audio(audio_buff);
break;
}

Expand Down
10 changes: 5 additions & 5 deletions customchar/audio/voice_recorder.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,20 @@ class VoiceRecorder {
VoiceRecorder();

/// @brief Clear audio_ buffer to prepare for new recording
void ClearAudioBuffer();
void clear_audio_buffer();

/// @brief Sample audio_
void SampleAudio();
void sample_audio();

/// @brief Check if finished talking
bool FinishedTalking();
bool finished_talking();

/// @brief Get final audio_
void GetAudio(std::vector<float>& result);
void get_audio(std::vector<float>& result);

/// @brief Record speech from user
/// @return Audio buffer from user
std::vector<float> RecordSpeech();
std::vector<float> record_speech();
}; // class VoiceRecorder

} // namespace audio
Expand Down
Loading

0 comments on commit b9f14c9

Please sign in to comment.