diff --git a/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py
index 9df4f0b76..614159224 100644
--- a/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py
+++ b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py
@@ -125,15 +125,11 @@ def extract_text_and_time_tokens(input_sequence):
         word_length = []
         for idx, word in enumerate(words):
             # Tokenize the word using the provided text processor
-            if id == 0:
-                logging.debug(f'word: {word}')
             tokenized_word = self.text_processor._process_example(context="", output=word)
             # Remove the EOS token (assuming the EOS token is at the end of "answer_ids")
             token_ids = tokenized_word["answer_ids"][:-1]  # Remove EOS token
             if idx != 0:  # If not the first word, remove the first token
                 token_ids = token_ids[1:]
-            if id == 0:
-                logging.debug(f'token_ids: {token_ids}')
             token_length = len(token_ids)  # Calculate the length
             tokenized_words.extend(token_ids)
             word_length.append(token_length)
@@ -184,8 +180,6 @@ def _expand_text_with_timestamps_and_word_lengths(
 
                 # Reduction of start time index due to stacking of frames
                 start_time_index = int(start_time_index / self.decoder_reduction_factor)
-                if batch_idx == 0:
-                    logging.debug(f'start_time_index[0]: {start_time_index}')
                 # Calculate the end time index based on word length
                 end_time_index = start_time_index + word_length
 
@@ -210,8 +204,6 @@ def _expand_text_with_timestamps_and_word_lengths(
 
         cuts = cuts.sort_by_duration()
 
-        logging.debug(f"Len: {len(cuts)}")
-
         metadata = []
         instructions, instruction_lengths = [], []
         target_texts, target_text_lengths = [], []
@@ -311,8 +303,6 @@ def collate_and_pad(inputs):
                 tokens[i, : token_lengths[i], :] = inputs[i]
             return tokens, torch.LongTensor(token_lengths)
 
-        # import pdb; pdb.set_trace()
-
         target_codec = None
         answer_audios, answer_audio_lens = None, None
         assert self.load_answer_audio
@@ -390,8 +380,6 @@ def _convert_text_to_3d_tensor(texts, include_eos=True, tokens_to_generate=0):
                 texts_expanded = texts_expanded[:, :-1]
             return texts, text_lengths, texts_expanded
 
-        # import pdb; pdb.set_trace()
-
         unpadded_target_texts = target_texts
         target_texts, target_text_lengths, target_texts_expanded = _convert_text_to_3d_tensor(target_texts)
         instructions, instruction_lengths, instructions_expanded_no_eos = _convert_text_to_3d_tensor(
@@ -405,7 +393,6 @@ def _convert_text_to_3d_tensor(texts, include_eos=True, tokens_to_generate=0):
 
         # TODO: remove the following stanza
         if getattr(cut, "s2s", False):
-            logging.debug(f'target_texts_expanded: {target_texts_expanded[0,:]}')
             # Add 1 for eos token
             token_list = [
                 torch.concat([tt[: ttl + 1], tc[: tcl + 1]], 0)
@@ -446,19 +433,11 @@ def _convert_text_to_3d_tensor(texts, include_eos=True, tokens_to_generate=0):
                 self.codec_model_downsampling_factor / self.sample_rate,
                 pad_id=text_unk_id,
             )
 
-            # import pdb; pdb.set_trace()
-            logging.debug(f'start_time_token: {start_time_tokens[0]}')
-            logging.debug(f'word_length: {word_lengths[0]}')
-            logging.debug(f'target_tokens: {unpadded_target_texts[0]}')
-            logging.debug(f'target_texts_expanded: {target_texts_expanded[0,:]}')
             # [batch, max_feat_len, 1+V], where V = #codebooks * reduction_factor
             target_codec[:, :, 0] = target_texts_expanded
             token_list = torch.concat([bos_tensor, target_codec], 1)
             features_lens += 1
-            # import pdb; pdb.set_trace()
-
-            logging.debug(f'token_list[0].shape: {token_list[0].shape}')
             if not self.t5_style:
                 token_list = [
                     torch.concat([it[:itl], tt], 0)
@@ -473,7 +452,6 @@ def _convert_text_to_3d_tensor(texts, include_eos=True, tokens_to_generate=0):
                     loss_mask[:, :itl, :] = False
             full_lengths = features_lens + 1 + instruction_lengths
             target_text_lengths = -1 * torch.ones_like(target_text_lengths)
-            # import pdb; pdb.set_trace()
         elif getattr(cut, "direct_s2s", False):
             # Add 1 for eos token
             # tt[0] is the bos token
diff --git a/nemo/collections/multimodal/speech_llm/models/modular_s2s_models.py b/nemo/collections/multimodal/speech_llm/models/modular_s2s_models.py
index 7bf6b33ca..cabcd2e66 100644
--- a/nemo/collections/multimodal/speech_llm/models/modular_s2s_models.py
+++ b/nemo/collections/multimodal/speech_llm/models/modular_s2s_models.py
@@ -399,27 +399,38 @@ def inference_step(self, dataloader_iter, mode):
         return outputs
 
     def parse_decoder_outputs(
-        self, input_decoder_output, text_separator, context_length, speech_pad_id=1001, speech_eos_id=1004
+        self, input_decoder_output, text_separator, context_length, speech_pad_id=1001, speech_eos_id=1004, text_pad_id=0,
     ):
         # remove text context
         max_len = input_decoder_output.shape[0]
         decoder_output = input_decoder_output[-1:].tile([max_len, 1])
         decoder_output[: max_len - context_length] = input_decoder_output[context_length:]
 
-        # Do not split because text and speech are now aligned
-        # Split text and speech part based on the position of the first separator token
-        # sep_pos = (decoder_output[:, 0] == text_separator).long()
-        # if torch.any(sep_pos):
-        #     first_sep_pos = torch.argmax(sep_pos)
-        #     text_tokens = decoder_output[:first_sep_pos, 0]
-        #     speech_tokens = decoder_output[first_sep_pos + 1 :, 1:]
-        # else:
-        #     text_tokens = decoder_output[:, 0]
-        #     speech_tokens = decoder_output[:, 1:]
-        text_tokens = decoder_output[:, 0]
-        speech_tokens = decoder_output[:, 1:]
-
-        # import pdb; pdb.set_trace()
+        text_channel = decoder_output[:, 0]
+
+        # adhoc: Suppose the text_pad_id appear before text_eos in the align_s2s case
+        sep_indices = (text_channel == text_separator).nonzero(as_tuple=True)[0]
+        index_sep = sep_indices[0].item() if sep_indices.numel() > 0 else None
+        pad_indices = (text_channel == text_pad_id).nonzero(as_tuple=True)[0]
+        index_pad = pad_indices[0].item() if pad_indices.numel() > 0 else None
+        is_align_s2s = index_pad is not None and (index_sep is None or index_pad < index_sep)
+        is_align_s2s = True
+        if is_align_s2s:
+            text_tokens_with_pads = decoder_output[:, 0]
+            text_tokens = text_tokens_with_pads[text_tokens_with_pads != text_pad_id]
+            speech_tokens = decoder_output[:, 1:]
+        else:
+            # s2s predicts [text, text_separator] for the first channel
+            sep_pos = (text_channel == text_separator).long()
+            is_s2s = torch.any(sep_pos)
+            if is_s2s:
+                first_sep_pos = torch.argmax(sep_pos)
+                text_tokens = decoder_output[:first_sep_pos, 0]
+                speech_tokens = decoder_output[first_sep_pos + 1 :, 1:]
+            else:
+                # direct_s2s
+                text_tokens = decoder_output[:, 0]
+                speech_tokens = decoder_output[:, 1:]
 
         # Get speech token ids
         n_speech_codebooks = self.model.n_proj_heads - 1
@@ -824,7 +835,6 @@ def write_predictions_to_file(self, outputs, output_file_path_prefix, output_dir
     def de_concat_multiproj_logits(self, logits):
         logits_list = []
         prev = 0
-        # import pdb; pdb.set_trace()
         for i in self.model.proj_head_dims:
             logits_list.append(logits[:, prev : prev + i])
             prev += i
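Note on the `parse_decoder_outputs` change above: the text channel (column 0 of the decoder output) is now parsed three ways, depending on whether text and speech are frame-aligned with pad tokens (align_s2s), delimited by `text_separator` (s2s), or emitted in parallel with no separator (direct_s2s). The diff then hard-codes `is_align_s2s = True`, so only the aligned branch is currently exercised and the detection plus fallback branches serve as scaffolding. The snippet below is a minimal, self-contained sketch of that splitting logic on a toy tensor; the helper name `split_text_and_speech` and the token ids used are illustrative placeholders, not values taken from the PR or the NeMo configs.

```python
import torch


def split_text_and_speech(decoder_output: torch.Tensor, text_separator: int, text_pad_id: int = 0):
    """Sketch of the new channel-splitting behaviour.

    Assumes decoder_output has shape [T, 1 + n_codebooks]: column 0 is the text
    channel, the remaining columns are speech codec channels.
    """
    text_channel = decoder_output[:, 0]

    sep_indices = (text_channel == text_separator).nonzero(as_tuple=True)[0]
    pad_indices = (text_channel == text_pad_id).nonzero(as_tuple=True)[0]
    index_sep = sep_indices[0].item() if sep_indices.numel() > 0 else None
    index_pad = pad_indices[0].item() if pad_indices.numel() > 0 else None

    if index_pad is not None and (index_sep is None or index_pad < index_sep):
        # align_s2s: text is padded out to the speech length, so pads appear
        # before any separator; drop the pads, keep every speech frame.
        text_tokens = text_channel[text_channel != text_pad_id]
        speech_tokens = decoder_output[:, 1:]
    elif index_sep is not None:
        # s2s: text first, then the separator, then speech-only frames.
        text_tokens = text_channel[:index_sep]
        speech_tokens = decoder_output[index_sep + 1 :, 1:]
    else:
        # direct_s2s: no separator; both channels run in parallel.
        text_tokens = text_channel
        speech_tokens = decoder_output[:, 1:]
    return text_tokens, speech_tokens


# Toy example: 6 frames, text ids 5 and 6 followed by pads in the text channel.
toy = torch.tensor([[5, 11], [6, 12], [0, 13], [0, 14], [0, 15], [0, 16]], dtype=torch.long)
text, speech = split_text_and_speech(toy, text_separator=2, text_pad_id=0)
print(text.tolist())  # [5, 6]
print(speech.shape)   # torch.Size([6, 1])
```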