From 19022e6463b5534b8e5afedf58873fa6c29a7645 Mon Sep 17 00:00:00 2001
From: kevinhu
Date: Tue, 26 Nov 2024 06:48:31 -0800
Subject: [PATCH] fix sample_rate issues

---
 .../multimodal/speech_llm/data/build_dataset.py  | 4 ++++
 .../multimodal/speech_llm/data/lhotse_dataset.py | 8 +++++---
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/nemo/collections/multimodal/speech_llm/data/build_dataset.py b/nemo/collections/multimodal/speech_llm/data/build_dataset.py
index 5627a6665..5e669ce3c 100644
--- a/nemo/collections/multimodal/speech_llm/data/build_dataset.py
+++ b/nemo/collections/multimodal/speech_llm/data/build_dataset.py
@@ -93,6 +93,10 @@ def build_speechllm_dataset(model_instance, data_cfg, is_train):
         speech_eos_id=data_cfg.get('speech_eos_id', 1004),
         filter_by_source_target_text_ratio=data_cfg.get('filter_by_source_target_text_ratio', False),
         source_target_text_ratio_limit=data_cfg.get('source_target_text_ratio_limit', 1.0),
+        load_answer_audio=data_cfg.get('load_answer_audio', False),
+        codec_model_downsampling_factor=data_cfg.get('codec_model_downsampling_factor', 1024),
+        sample_rate=data_cfg.get('sample_rate', 16000),
+        codec_sample_rate=data_cfg.get('target_audio_sample_rate', 22050),
     )
 
     # Notably, the data weights are controlled by either bucketing_weights
diff --git a/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py
index 5130e9946..072b684a6 100644
--- a/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py
+++ b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py
@@ -67,7 +67,8 @@ def __init__(
         speech_eos_id: int = 1004,
         filter_by_source_target_text_ratio: bool = False,
         source_target_text_ratio_limit: float = 1.0,
-        sample_rate: int = 22050,
+        sample_rate: int = 16000,
+        codec_sample_rate: int = 22050,
         t5_style: bool = False,
         load_answer_audio: bool = False,
         codec_model_downsampling_factor: float = 1023.5,
@@ -95,6 +96,7 @@ def __init__(
         self.filter_by_source_target_text_ratio = filter_by_source_target_text_ratio
         self.source_target_text_ratio_limit = source_target_text_ratio_limit
         self.sample_rate = sample_rate
+        self.codec_sample_rate = codec_sample_rate
         self.load_answer_audio = load_answer_audio
         self.codec_model_downsampling_factor = codec_model_downsampling_factor
 
@@ -338,7 +340,7 @@ def collate_and_pad(inputs):
             answer_audios = []
             features_lens = []
             for i, cut in enumerate(cuts):
-                answer_audio = torch.tensor(cut.target_audio.load_audio()).float()
+                answer_audio = torch.tensor(cut.target_audio.resample(self.codec_sample_rate).load_audio()).float()
                 answer_audio_len = torch.tensor(answer_audio.shape[1]).long()
                 answer_audios.append(answer_audio)
                 answer_audio_lens.append(answer_audio_len)
@@ -433,7 +435,7 @@ def _convert_text_to_3d_tensor(texts, include_eos=True, tokens_to_generate=0):
                     word_lengths,
                     start_time_tokens,
                     features_lens + 1,
-                    self.codec_model_downsampling_factor / self.sample_rate,
+                    self.codec_model_downsampling_factor / self.codec_sample_rate,
                     pad_id=text_unk_id,
                 )
             else:
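
For context, the core of the fix is the distinction between the input-audio sample rate (16 kHz default) and the sample rate the codec model operates at (22.05 kHz default): target audio is now resampled to the codec rate before loading, and the per-frame duration used for word-length / start-time alignment is computed against the codec rate. Below is a minimal standalone sketch of that duration arithmetic, using the defaults from the patch; the variable names and values are illustrative only, not taken from any NeMo config.

# Illustrative sketch (not part of the patch): the per-token duration must be
# derived from the codec's sample rate, not the input-audio sample rate.
codec_model_downsampling_factor = 1024  # codec samples consumed per codec frame
sample_rate = 16_000                    # input (source) audio sample rate
codec_sample_rate = 22_050              # sample rate the codec model runs at

frame_dur_old = codec_model_downsampling_factor / sample_rate        # 0.0640 s/frame
frame_dur_new = codec_model_downsampling_factor / codec_sample_rate  # ~0.0464 s/frame

print(f"old: {frame_dur_old:.4f} s/frame, new: {frame_dur_new:.4f} s/frame")
# Dividing by 16 kHz overstates each codec frame's duration by ~38%
# (22050 / 16000 ~= 1.38), which skews timings derived from features_lens.
# The same mismatch is why the target audio is now resampled to
# codec_sample_rate before load_audio() is called on it.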