Update __init__.py
Added code for faster-whisper and distil-whisper, based on Uberi#730.
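
A minimal usage sketch of the new faster-whisper path (not part of the commit; it assumes the faster-whisper and soundfile packages are installed, and the audio file name is illustrative):

    import speech_recognition as sr

    r = sr.Recognizer()
    with sr.AudioFile("test.wav") as source:  # illustrative file; any format SpeechRecognition reads works
        audio = r.record(source)

    # model follows faster-whisper checkpoint names ("tiny", "base", "small", ...)
    print(r.recognize_fasterwhisper(audio, model="small"))
    print(r.recognize_fasterwhisper(audio, model="small", show_dict=True))  # minimal dict with text and language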
SwamiKannan committed Jul 1, 2024
1 parent 75a7f6b commit 848e20f
Showing 1 changed file with 83 additions and 1 deletion: speech_recognition/__init__.py
@@ -43,6 +43,8 @@
__version__ = "3.10.4"
__license__ = "BSD"

# Local cache locations for downloaded distil-whisper weights and tokenizers
# (hardcoded Windows paths from the author's machine).
MODEL_PATH = 'N:\\models\\voice\\model\\'
TOKENIZER_PATH = 'N:\\models\\voice\\tokenizer\\'

class AudioSource(object):
    def __init__(self):
@@ -324,7 +326,7 @@ def read(self, size=-1):


class Recognizer(AudioSource):
    def __init__(self):
    def __init__(self, model=None):
"""
Creates a new ``Recognizer`` instance, which represents a collection of speech recognition functionality.
"""
@@ -337,6 +339,7 @@ def __init__(self):

        self.phrase_threshold = 0.3  # minimum seconds of speaking audio before we consider the speaking audio a phrase - values below this are ignored (for filtering out clicks and pops)
        self.non_speaking_duration = 0.5  # seconds of non-speaking audio to keep on both sides of the recording
        self.model = model  # optional pre-built ASR pipeline, consumed by recognize_distilwhisper

    def record(self, source, duration=None, offset=None):
        """
@@ -1385,7 +1388,86 @@ def recognize_tensorflow(self, audio_data, tensor_graph='tensorflow-data/conv_ac…
        for node_id in top_k:
            human_string = self.tflabels[node_id]
            return human_string
    def recognize_fasterwhisper(self, audio_data, model="small", show_dict=False, load_options=None, language=None, translate=False, **transcribe_options):
        # Custom recognizer for faster-whisper.
        assert isinstance(audio_data, AudioData), "Data must be audio data"
        import numpy as np
        import soundfile as sf
        import torch
        from faster_whisper import WhisperModel

        # Cache one WhisperModel per checkpoint name; reload when load_options is given.
        if load_options or not hasattr(self, "whisper_model") or self.whisper_model.get(model) is None:
            self.whisper_model = getattr(self, "whisper_model", {})
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.whisper_model[model] = WhisperModel(model, device=device, compute_type="auto")

        # faster-whisper expects 16 kHz mono float32 audio.
        wav_bytes = audio_data.get_wav_data(convert_rate=16000)
        wav_stream = io.BytesIO(wav_bytes)
        audio_array, sampling_rate = sf.read(wav_stream)
        audio_array = audio_array.astype(np.float32)

        segments, info = self.whisper_model[model].transcribe(
            audio_array, beam_size=5, language=language,
            task="translate" if translate else "transcribe", **transcribe_options)
        text = ""
        for segment in segments:
            text = text + segment.text + " "
        if show_dict:
            # faster-whisper has no single dict result, so assemble a minimal one.
            return {"text": text.lower().strip(), "language": info.language,
                    "language_probability": info.language_probability}
        return text.lower()

    def recognize_distilwhisper(self, audio_data, model="distil-whisper/distil-medium.en", show_dict=False, load_options=None, language=None, translate=False, **transcribe_options):
        # Custom recognizer for distil-whisper (see the usage sketch after the diff).
        assert isinstance(audio_data, AudioData), "Data must be audio data"
        import numpy as np
        import soundfile as sf
        import torch
        from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        model_id = model  # honor the requested checkpoint instead of a hardcoded one
        model_name = model_id.split('/')[-1]
        model_cache_path = MODEL_PATH + model_name
        tokenizer_cache_path = TOKENIZER_PATH + model_name

        # Only build a pipeline when the caller did not inject one via Recognizer(model=...).
        if not self.model:
            hf_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True,
                use_safetensors=True, cache_dir=model_cache_path)
            hf_model.to(device)
            processor = AutoProcessor.from_pretrained(model_id, cache_dir=tokenizer_cache_path)

            whisper = pipeline(
                "automatic-speech-recognition", model=hf_model, tokenizer=processor.tokenizer,
                feature_extractor=processor.feature_extractor, max_new_tokens=128,
                torch_dtype=torch_dtype, device=device)
            if not os.path.exists(tokenizer_cache_path):
                processor.save_pretrained(tokenizer_cache_path)
            if not os.path.exists(model_cache_path):
                hf_model.save_pretrained(model_cache_path)

        wav_bytes = audio_data.get_wav_data(convert_rate=16000)
        wav_stream = io.BytesIO(wav_bytes)
        audio_array, sampling_rate = sf.read(wav_stream)
        audio_array = audio_array.astype(np.float32)  # the feature extractor expects float32 audio

        # Prefer a pipeline supplied at construction time; otherwise use the one built above.
        asr = self.model if self.model else whisper
        result = asr(audio_array, chunk_length_s=50, stride_length_s=10, batch_size=8)
        if show_dict:
            return result
        return result["text"]

    def recognize_whisper(self, audio_data, model="base", show_dict=False, load_options=None, language=None, translate=False, **transcribe_options):
        """
        Performs speech recognition on ``audio_data`` (an ``AudioData`` instance), using Whisper.
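A hedged sketch of both ways to drive the new distil-whisper path: letting recognize_distilwhisper build and cache its own pipeline, or injecting a pre-built transformers pipeline through the new Recognizer(model=...) parameter. The model id mirrors the committed code; the microphone capture (which needs PyAudio) is illustrative, and the hardcoded MODEL_PATH/TOKENIZER_PATH must point at writable locations on your machine (they can be reassigned after import, e.g. sr.MODEL_PATH = 'models/'):

    import speech_recognition as sr
    import torch
    from transformers import pipeline

    r = sr.Recognizer()
    with sr.Microphone() as source:  # illustrative capture; any AudioData works
        audio = r.listen(source)

    # Mode 1: the method downloads the checkpoint itself and caches it under MODEL_PATH.
    print(r.recognize_distilwhisper(audio, model="distil-whisper/distil-small.en"))

    # Mode 2: hand the Recognizer a ready-made ASR pipeline; the method then skips loading.
    asr = pipeline(
        "automatic-speech-recognition",
        model="distil-whisper/distil-small.en",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device="cuda:0" if torch.cuda.is_available() else "cpu",
    )
    r2 = sr.Recognizer(model=asr)
    print(r2.recognize_distilwhisper(audio, show_dict=True))  # full pipeline output dict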
