From 0c86af29b3fa9c79c647a76e696ad181e65b3878 Mon Sep 17 00:00:00 2001
From: pajowu
Date: Sun, 3 Dec 2023 23:50:53 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20create=5Fapi=5Ftoken=20man?=
 =?UTF-8?q?agement=20command?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
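Notes (below the --- marker, so git am drops them): create_api_token.py
previously called SessionContextManager() without a path label, which
presumably is what broke the command, and the label passed by the other
management commands was misspelled ("mangement_comment:*"); both are
normalized to "management_command:*" here. A usage sketch for the fixed
command, assuming it is run from the backend/ directory with the project
environment active (the token name is illustrative):

    $ python scripts/create_api_token.py --name ci-worker
    Token created: <generated token>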
 backend/scripts/create_api_token.py          |   2 +-
 backend/scripts/create_user.py               |   2 +-
 backend/scripts/create_worker.py             |   2 +-
 backend/scripts/reset_task.py                |   2 +-
 backend/scripts/set_password.py              |   2 +-
 .../transcribee_worker/whisper_transcribe.py | 341 ++++++++++++++----
 6 files changed, 266 insertions(+), 85 deletions(-)

diff --git a/backend/scripts/create_api_token.py b/backend/scripts/create_api_token.py
index 98b602a9..91498482 100644
--- a/backend/scripts/create_api_token.py
+++ b/backend/scripts/create_api_token.py
@@ -7,6 +7,6 @@
     parser = argparse.ArgumentParser()
     parser.add_argument("--name", required=True)
     args = parser.parse_args()
-    with SessionContextManager() as session:
+    with SessionContextManager(path="management_command:create_api_token") as session:
         token = create_api_token(session=session, name=args.name)
         print(f"Token created: {token.token}")
diff --git a/backend/scripts/create_user.py b/backend/scripts/create_user.py
index 1263b255..6ae3f7d7 100644
--- a/backend/scripts/create_user.py
+++ b/backend/scripts/create_user.py
@@ -9,7 +9,7 @@
     parser.add_argument("--user", required=True)
     parser.add_argument("--pass", required=True)
     args = parser.parse_args()
-    with SessionContextManager(path="mangement_comment:create_user") as session:
+    with SessionContextManager(path="management_command:create_user") as session:
         try:
             user = create_user(
                 session=session, username=args.user, password=getattr(args, "pass")
diff --git a/backend/scripts/create_worker.py b/backend/scripts/create_worker.py
index 536a6e5f..51c2a52d 100644
--- a/backend/scripts/create_worker.py
+++ b/backend/scripts/create_worker.py
@@ -13,7 +13,7 @@
     if args.token is None:
         args.token = utils.get_random_string()
 
-    with SessionContextManager(path="mangement_comment:create_worker") as session:
+    with SessionContextManager(path="management_command:create_worker") as session:
         statement = select(Worker).where(Worker.token == args.token)
         results = session.exec(statement)
         existing_worker = results.one_or_none()
diff --git a/backend/scripts/reset_task.py b/backend/scripts/reset_task.py
index aa9b8f3f..08528eb8 100644
--- a/backend/scripts/reset_task.py
+++ b/backend/scripts/reset_task.py
@@ -12,7 +12,7 @@
         "--uuid", required=True, type=uuid.UUID, help="Task UUID or Document UUID"
     )
     args = parser.parse_args()
-    with SessionContextManager(path="mangement_comment:reset_task") as session:
+    with SessionContextManager(path="management_command:reset_task") as session:
         task = session.execute(
             update(Task)
             .where(
diff --git a/backend/scripts/set_password.py b/backend/scripts/set_password.py
index de63dd43..f661bb3a 100644
--- a/backend/scripts/set_password.py
+++ b/backend/scripts/set_password.py
@@ -10,7 +10,7 @@
     parser.add_argument("--pass", required=True)
     args = parser.parse_args()
 
-    with SessionContextManager(path="mangement_comment:set_password") as session:
+    with SessionContextManager(path="management_command:set_password") as session:
         try:
             user = change_user_password(
                 session=session, username=args.user, new_password=getattr(args, "pass")
diff --git a/worker/transcribee_worker/whisper_transcribe.py b/worker/transcribee_worker/whisper_transcribe.py
index 7f5d6ab1..caec0dcf 100644
--- a/worker/transcribee_worker/whisper_transcribe.py
+++ b/worker/transcribee_worker/whisper_transcribe.py
@@ -1,14 +1,14 @@
+import logging
 import re
-from typing import TYPE_CHECKING, Iterator, Optional
+from typing import TYPE_CHECKING, Any, AsyncIterator, List, Optional
 
-import faster_whisper
-import faster_whisper.transcribe
-from faster_whisper import WhisperModel
+import requests
 from numpy.typing import NDArray
 from transcribee_proto.document import Atom, Paragraph
 from transcribee_worker.config import settings
 from transcribee_worker.types import ProgressCallbackType
 from transcribee_worker.util import SubmissionQueue, async_task
+from whispercppy import api
 
 if TYPE_CHECKING:
     from .icu import BreakIterator, Locale
@@ -29,22 +29,189 @@
 ]
 
 
-def move_space_to_prev_token(
-    iter: Iterator[Paragraph],
-) -> Iterator[Paragraph]:
-    last_paragraph = next(iter)
-    last_paragraph.children[0].text = last_paragraph.children[0].text.lstrip()
-    _para_move_space_to_prev_token(last_paragraph)
+def get_model_file(model_name: str):
+    whisper_models_dir = settings.MODELS_DIR / "whisper"
+    whisper_models_dir.mkdir(parents=True, exist_ok=True)
+    model_file = whisper_models_dir / f"{model_name}.bin"
 
-    for paragraph in iter:
-        para_starts_with_whitespace = paragraph.children[0].text[:1].isspace()
-        if para_starts_with_whitespace:
-            last_paragraph.children[-1].text += paragraph.children[0].text[:1]
-            paragraph.children[0].text = paragraph.children[0].text[1:]
+    if not model_file.exists():
+        logging.info(f"downloading model {model_name} because it does not exist yet...")
+        base_url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main"
+        url = f"{base_url}/ggml-{model_name}.bin"
+        r = requests.get(url, allow_redirects=True)
+        r.raise_for_status()
+        with model_file.open(mode="wb") as f:
+            f.write(r.content)
+
+    return model_file
+
+
+def get_context(model_name: str) -> api.Context:
+    model_file = get_model_file(model_name)
+    logging.info(f"loading model {model_name}...")
+    ctx = api.Context.from_file(str(model_file))
+    ctx.reset_timings()
+    return ctx
+
+
+# TODO(robin): this currently filters all special tokens
+# recovery of multilingual text could be hard if we keep this filtering
+def _transcription_work(
+    queue: SubmissionQueue,
+    data: NDArray[Any],
+    start_offset: float,
+    model_name: str,
+    lang_code: Optional[str],
+    progress_callback: Optional[ProgressCallbackType],
+):
+    def handle_new_segment(
+        ctx: api.Context,
+        n_new: int,
+        queue: SubmissionQueue,
+    ):
+        segment = ctx.full_n_segments() - n_new
+
+        rest_token_bytes = b""
+        rest_conf = 0
+        rest_count = 0
+        rest_start = 0
+        rest_conf_ts = 0
+
+        lang: str
+        if lang_code is None or lang_code in ["", "auto"]:
+            lang = ctx.lang_id_to_str(ctx.full_lang_id())
+        else:
+            lang = lang_code
+
+        while segment < ctx.full_n_segments():
+            tokens = (
+                ctx.full_get_token_data(segment, token_idx)
+                for token_idx in range(ctx.full_n_tokens(segment))
+            )
+
+            atoms = []
+            for token in tokens:
+                if token.id in special_tokens or token.id > special_tokens[-1]:
+                    continue
+
+                token_bytes = ctx.token_to_bytes(token.id)
+                conf = token.p
+                conf_ts = token.pt
+                start = token.t0
+                end = token.t1
+
+                # tokens can be incomplete utf-8, so we sometimes need to combine tokens
+                # to get valid utf-8. we assume such invalid utf-8 cannot span multiple segments
+                try:
+                    text = (rest_token_bytes + token_bytes).decode("utf-8")
+                    conf = (rest_conf + conf) / (rest_count + 1)
+                    conf_ts = (rest_conf_ts + conf_ts) / (rest_count + 1)
+                    if rest_start != 0:
+                        start = rest_start
+                except UnicodeDecodeError:
+                    logging.info(
+                        "invalid utf-8 encountered in whisper token, skipping decoding, "
+                        "appending to rest"
+                    )
+                    rest_token_bytes += token_bytes
+                    rest_conf += conf
+                    rest_count += 1
+                    rest_conf_ts += conf_ts
+                    if rest_start == 0:  # remember the start of the first buffered token
+                        rest_start = start
+                    continue
+
+                rest_token_bytes = b""
+                rest_conf = 0
+                rest_conf_ts = 0
+                rest_count = 0
+                rest_start = 0
+
+                atoms.append(
+                    Atom(
+                        text=text,
+                        conf=conf,
+                        # 10 ms units -> seconds
+                        start=(start / 100) + start_offset,
+                        # 10 ms units -> seconds
+                        end=(end / 100) + start_offset,
+                        conf_ts=conf_ts,
+                    )
+                )
+
+            paragraph = Paragraph(
+                children=atoms,
+                lang=lang,
+            )
+
+            queue.submit(paragraph)
+            segment += 1
+
+    ctx = get_context(model_name)
+
+    special_tokens = [
+        ctx.eot_token,  # type: ignore
+        ctx.sot_token,  # type: ignore
+        ctx.prev_token,  # type: ignore
+        ctx.solm_token,  # type: ignore
+        ctx.not_token,  # type: ignore
+        ctx.beg_token,  # type: ignore
+    ]
+
+    sampling = api.SamplingStrategies.from_enum(api.SAMPLING_GREEDY)
+    sampling.greedy.best_of = 5  # parameter stolen from whisper.cpp cli
+    params = (
+        api.Params.from_sampling_strategy(sampling)
+        .with_no_context(
+            False
+        )  # if False, feeds the already transcribed text back to the model
+        .with_num_threads(4)
+        .with_max_segment_length(0)  # Unlimited segment length
+        .with_token_timestamps(True)
+    )
+    if lang_code is not None:
+        params = params.with_language(lang_code)
+    params.on_new_segment(handle_new_segment, queue)
+    if progress_callback is not None:
+        params.on_progress(
+            lambda _ctx, progress, _data: progress_callback(progress=progress / 100),
+            None,
+        )
+    ctx.full(params, data)
+
+
+def transcribe(
+    data: NDArray,
+    start_offset: float,
+    model_name: str,
+    lang_code="en",
+    progress_callback=None,
+) -> AsyncIterator[Paragraph]:
+    return async_task(
+        _transcription_work,
+        data,
+        start_offset,
+        model_name,
+        lang_code,
+        progress_callback,
+    )
 
-        yield last_paragraph
-        paragraph = _para_move_space_to_prev_token(paragraph)
-        last_paragraph = paragraph
+
+async def recombine_split_words(
+    iter: AsyncIterator[Paragraph],
+) -> AsyncIterator[Paragraph]:
+    last_paragraph = None
+    async for paragraph in iter:
+        if last_paragraph is None:
+            last_paragraph = paragraph
+            continue
+
+        starts_with_whitespace = paragraph.text()[:1].isspace()
+        if starts_with_whitespace:
+            yield last_paragraph
+            last_paragraph = paragraph
+        else:
+            last_paragraph.children.extend(paragraph.children)
 
     if last_paragraph is not None:
         yield last_paragraph
@@ -59,33 +226,34 @@ def _para_move_space_to_prev_token(paragraph: Paragraph):
     return paragraph
 
 
-def whisper_segment_to_transcribee_segment(
-    iter: Iterator[faster_whisper.transcribe.Segment], lang: str, start_offset: float
-) -> Iterator[Paragraph]:
-    for seg in iter:
-        assert seg.words is not None
-        yield Paragraph(
-            children=[
-                Atom(
-                    text=word.word,
-                    start=word.start + start_offset,
-                    end=word.end + start_offset,
-                    conf=word.probability,
-                    conf_ts=1,
-                )
-                for word in seg.words
-            ],
-            lang=lang,
-        )
+async def move_space_to_prev_token(
+    iter: AsyncIterator[Paragraph],
+) -> AsyncIterator[Paragraph]:
+    last_paragraph = await anext(iter)
+    last_paragraph.children[0].text = last_paragraph.children[0].text.lstrip()
+    _para_move_space_to_prev_token(last_paragraph)
 
+    async for paragraph in iter:
+        para_starts_with_whitespace = paragraph.children[0].text[:1].isspace()
+        if para_starts_with_whitespace:
+            last_paragraph.children[-1].text += paragraph.children[0].text[:1]
+            paragraph.children[0].text = paragraph.children[0].text[1:]
+
+        yield last_paragraph
+        paragraph = _para_move_space_to_prev_token(paragraph)
+        last_paragraph = paragraph
 
-def strict_sentence_paragraphs(
-    iter: Iterator[Paragraph],
-) -> Iterator[Paragraph]:
+    if last_paragraph is not None:
+        yield last_paragraph
+
+
+async def strict_sentence_paragraphs(
+    iter: AsyncIterator[Paragraph],
+) -> AsyncIterator[Paragraph]:
     acc_paragraph = None
     acc_used_paras = []
     combination_active = True
-    for paragraph in iter:
+    async for paragraph in iter:
         if not combination_active:
             yield paragraph
             continue
@@ -161,58 +329,71 @@ def strict_sentence_paragraphs(
         yield acc_paragraph
 
 
-def transcribe_clean(
-    queue: SubmissionQueue,
+async def combine_tokens_to_words(
+    iter: AsyncIterator[Paragraph],
+) -> AsyncIterator[Paragraph]:
+    async for paragraph in iter:
+        locale = Locale(paragraph.lang)
+        word_iter = BreakIterator.createWordInstance(locale)
+        word_iter.setText(paragraph.text())
+        breaks: List[int] = list(word_iter)
+        assert breaks[-1] == len(paragraph.text())
+
+        new_para = Paragraph(
+            children=[], speaker=paragraph.speaker, lang=paragraph.lang
+        )
+        pos = 0
+        current_atom = None
+        for atom in paragraph.children:
+            pos_after_atom = pos + len(atom.text)
+            if current_atom is None:
+                current_atom = Atom(
+                    text=atom.text,
+                    conf=atom.conf,
+                    start=atom.start,
+                    end=atom.end,
+                    conf_ts=atom.conf_ts,
+                )
+                pos = pos_after_atom
+            else:
+                current_atom.text += atom.text
+                current_atom.end = atom.end
+                current_atom.conf = min(current_atom.conf, atom.conf)
+                current_atom.conf_ts = min(current_atom.conf_ts, atom.conf_ts)
+                pos = pos_after_atom
+
+            if pos_after_atom in breaks:
+                new_para.children.append(current_atom)
+                current_atom = None
+
+        if current_atom is not None:
+            new_para.children.append(current_atom)
+        yield new_para
+
+
+async def transcribe_clean(
     data: NDArray,
-    sr: int,
     start_offset: float,
     model_name: str,
-    progress_callback: ProgressCallbackType,
-    lang_code: Optional[str] = "en",
+    lang_code: str = "en",
+    progress_callback=None,
 ):
     chain = (
+        recombine_split_words,
         move_space_to_prev_token,
         strict_sentence_paragraphs,
+        combine_tokens_to_words,
    )
-    model = WhisperModel(
-        model_size_or_path=model_name,
-        download_root=str((settings.MODELS_DIR / "faster_whisper").absolute()),
-    )
-    seg_iter, info = model.transcribe(
-        audio=data, word_timestamps=True, language=lang_code
-    )
-    seg_iter = whisper_segment_to_transcribee_segment(
-        iter(seg_iter), lang=info.language, start_offset=start_offset
-    )
-    total_len = len(data) / sr
-    for elem in chain:
-        seg_iter = elem(seg_iter)
-    for v in seg_iter:
-        queue.submit(v)
-        if progress_callback is not None and v.children:
-            progress = (v.children[-1].end - start_offset) / total_len
-            progress_callback(
-                progress=progress,
-                step="transcribe",
-            )
-
-
-def transcribe_clean_async(
-    data: NDArray,
-    sr: int,
-    start_offset: float,
-    model_name: str,
-    progress_callback: ProgressCallbackType,
-    lang_code: Optional[str] = "en",
-):
-    return aiter(
-        async_task(
-            transcribe_clean,
+    iter = aiter(
+        transcribe(
             data=data,
-            sr=sr,
             start_offset=start_offset,
             model_name=model_name,
            lang_code=lang_code,
             progress_callback=progress_callback,
         )
     )
+    for elem in chain:
+        iter = elem(iter)
+    async for v in iter:
+        yield v
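
Appendix (not part of the patch): a minimal standalone sketch of the UTF-8
recombination idea used in handle_new_segment above. whisper.cpp emits tokens
as byte sequences that can split multi-byte characters, so undecodable byte
runs are buffered until they decode cleanly. The helper below is hypothetical
and only illustrates the buffering scheme, without the confidence averaging
and timestamp bookkeeping:

    def merge_token_bytes(token_bytes_seq):
        # yield decoded strings, buffering byte runs that are not yet valid utf-8
        rest = b""
        for token_bytes in token_bytes_seq:
            try:
                text = (rest + token_bytes).decode("utf-8")
            except UnicodeDecodeError:
                rest += token_bytes  # incomplete multi-byte character, keep buffering
                continue
            rest = b""
            yield text

    # "ä" is 0xc3 0xa4 in utf-8; split across two tokens it only decodes once combined
    assert list(merge_token_bytes([b"ab", b"\xc3", b"\xa4", b"c"])) == ["ab", "ä", "c"]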