Skip to content

Commit

Permalink
seed, multilingual and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiltseb committed Jun 9, 2023
1 parent 1bb7e33 commit fc54cb9
Showing 1 changed file with 65 additions and 6 deletions.
71 changes: 65 additions & 6 deletions faster_whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class TranscriptionOptions(NamedTuple):
patience: float
length_penalty: float
log_prob_threshold: Optional[float]
log_prob_low_threshold: Optional[float]
no_speech_threshold: Optional[float]
compression_ratio_threshold: Optional[float]
condition_on_previous_text: bool
Expand All @@ -61,7 +62,8 @@ class TranscriptionOptions(NamedTuple):
word_timestamps: bool
prepend_punctuations: str
append_punctuations: str

multilingual: bool
output_language: str

class TranscriptionInfo(NamedTuple):
language: str
Expand All @@ -79,7 +81,7 @@ def __init__(
device: str = "auto",
device_index: Union[int, List[int]] = 0,
compute_type: str = "default",
cpu_threads: int = 0,
cpu_threads: int = 16,
num_workers: int = 1,
download_root: Optional[str] = None,
local_files_only: bool = False,
Expand Down Expand Up @@ -120,6 +122,7 @@ def __init__(
cache_dir=download_root,
)

ctranslate2.set_random_seed(42)
self.model = ctranslate2.models.Whisper(
model_path,
device=device,
Expand Down Expand Up @@ -168,6 +171,7 @@ def transcribe(
],
compression_ratio_threshold: Optional[float] = 2.4,
log_prob_threshold: Optional[float] = -1.0,
log_prob_low_threshold: Optional[float] = -2.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
Expand All @@ -179,6 +183,8 @@ def transcribe(
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
multilingual: bool = False,
output_language: Optional[str] = None,
vad_filter: bool = False,
vad_parameters: Optional[Union[dict, VadOptions]] = None,
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
Expand All @@ -201,6 +207,9 @@ def transcribe(
treat as failed.
log_prob_threshold: If the average log probability over sampled tokens is
below this value, treat as failed.
log_prob_low_threshold: If the average log probability over sampled tokens is below
this value, skip the segment regardless of the no_speech probability, whereas
`log_prob_threshold` only marks a segment as silent when `no_speech_threshold` is
also exceeded. This value should be less than `log_prob_threshold`.
no_speech_threshold: If the no_speech probability is higher than this value AND
the average log probability over sampled tokens is below `log_prob_threshold`,
consider the segment as silent.
Expand All @@ -221,6 +230,10 @@ def transcribe(
with the next word
append_punctuations: If word_timestamps is True, merge these punctuation symbols
with the previous word
multilingual: If True, perform transcription on multilingual videos and return the
transcript in the form selected by `output_language`.
output_language: Valid only if multilingual is set to True. Specifies the output
language: one of 'en' (English) or 'hybrid' (code-switched transcription).
Falls back to 'en' when left unset or set to any other value.
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad.
Expand Down Expand Up @@ -278,6 +291,19 @@ def transcribe(
encoder_output = None
all_language_probs = None

if not multilingual:
if output_language is not None:
self.logger.info(
"No need to set the output language for mono-lingual videos. Ignoring the parameter..."
)
else:
if output_language is None:
output_language = "en"
elif output_language not in ["en","hybrid"]:
output_language = "en"
self.logger.info(
"Output language needs to be one of 'en'/'hybrid'. Setting to default language:'en'"
)
if language is None:
if not self.model.is_multilingual:
language = "en"
Expand Down Expand Up @@ -314,6 +340,7 @@ def transcribe(
patience=patience,
length_penalty=length_penalty,
log_prob_threshold=log_prob_threshold,
log_prob_low_threshold = log_prob_low_threshold,
no_speech_threshold=no_speech_threshold,
compression_ratio_threshold=compression_ratio_threshold,
condition_on_previous_text=condition_on_previous_text,
Expand All @@ -329,6 +356,8 @@ def transcribe(
word_timestamps=word_timestamps,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
multilingual = multilingual,
output_language = output_language,
)

segments = self.generate_segments(features, tokenizer, options, encoder_output)
Expand Down Expand Up @@ -379,16 +408,35 @@ def generate_segments(
)

previous_tokens = all_tokens[prompt_reset_since:]

if encoder_output is None:
encoder_output = self.encode(segment)

if options.multilingual and seek != 0: # language is already detected for first segment
results = self.model.detect_language(encoder_output)
language_token, language_probability = results[0][0]
language = language_token[2:-2]
if options.output_language == "en" and language != "en":
task = "translate"
else:
task = "transcribe"

tokenizer = Tokenizer(
self.hf_tokenizer,
self.model.is_multilingual,
task=task,
language=language,
)
#print(language)
#print(task)

prompt = self.get_prompt(
tokenizer,
previous_tokens,
without_timestamps=options.without_timestamps,
prefix=options.prefix if seek == 0 else None,
)

if encoder_output is None:
encoder_output = self.encode(segment)

(
result,
avg_logprob,
Expand All @@ -406,7 +454,12 @@ def generate_segments(
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False


if (
# skip if the logprob is very low, even when no_speech_prob is low (e.g. overly ambiguous output for music or noise input)
avg_logprob < options.log_prob_low_threshold
):
should_skip = True
if should_skip:
self.logger.debug(
"No speech threshold is met (%f > %f)",
Expand Down Expand Up @@ -645,6 +698,12 @@ def generate_with_fallback(
options.log_prob_threshold,
)

if (
options.no_speech_threshold is not None
and result.no_speech_prob > options.no_speech_threshold
):
needs_fallback = False # silence

if not needs_fallback:
break

Expand Down

0 comments on commit fc54cb9

Please sign in to comment.