From 07192dd03fa8fbef024bb3d7523f103110c1af24 Mon Sep 17 00:00:00 2001
From: kevinhu
Date: Tue, 10 Dec 2024 13:04:15 -0800
Subject: [PATCH] add support for using aligned_speech_text training for duplex

---
 .../speech_llm/data/lhotse_dataset.py         | 306 +++++++++++-------
 1 file changed, 189 insertions(+), 117 deletions(-)

diff --git a/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py
index 5c2732586..d0fa89d26 100755
--- a/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py
+++ b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py
@@ -1,5 +1,6 @@
 import math
 import random
+import re
 
 import torch.utils.data
 from lhotse import CutSet
@@ -107,6 +108,106 @@ def __init__(
         if self.codec_sample_rate != self.sample_rate:
             logging.info(f'{self.codec_sample_rate} {self.sample_rate} are different')
 
+    def _extract_text_and_time_tokens(self, input_sequence):
+        # Regular expression to match time tokens (e.g., <|x|> where x is an integer)
+        time_token_pattern = r"<\|(\d+)\|>"
+        # Find all time tokens
+        time_tokens = re.findall(time_token_pattern, input_sequence)
+        # Only keep the first token of every pair (i.e., start time tokens)
+        start_time_token = [int(time_tokens[i]) for i in range(0, len(time_tokens), 2)]
+        # Remove all time tokens to isolate words
+        words = re.sub(time_token_pattern, '', input_sequence).split()
+        # Process each word, tokenize it, and calculate token lengths
+        tokenized_words = []
+        word_length = []
+        for idx, word in enumerate(words):
+            # Tokenize the word using the provided text processor
+            if idx == 0:
+                logging.debug(f'word: {word}')
+            tokenized_word = self.text_processor._process_example(context="", output=word)
+            # Remove the EOS token (assuming the EOS token is at the end of "answer_ids")
+            token_ids = tokenized_word["answer_ids"][:-1]  # Remove EOS token
+            if idx != 0:  # If not the first word, remove the first token
+                token_ids = token_ids[1:]
+            if idx == 0:
+                logging.debug(f'token_ids: {token_ids}')
+            token_length = len(token_ids)  # Calculate the length
+            tokenized_words.extend(token_ids)
+            word_length.append(token_length)
+        return (
+            torch.as_tensor(tokenized_words),
+            torch.as_tensor(start_time_token),
+            torch.as_tensor(word_length),
+        )
+
+    def _expand_text_with_timestamps_and_word_lengths(
+        self, word_tokens, word_lengths, start_time_tokens, features_lens, frame_rate=0.08, pad_id=None
+    ):
+        """
+        Expand word tokens according to start time tokens and word lengths for a batch of sequences.
+
+        Args:
+        - word_tokens: List of lists of token sequences (each inner list is a word's token IDs), shape [batch][time].
+        - word_lengths: List of lists of word lengths, shape [batch][time].
+        - start_time_tokens: List of lists of start times, shape [batch][time].
+        - features_lens: Per-item frame counts; the batch maximum sets the time dimension of the output.
+        - frame_rate: Duration of one speech frame in seconds (e.g., 0.08).
+        - pad_id: Padding ID to use for empty positions in the tensor.
+
+        Returns:
+        - 2D tensor [batch, max_length] where each row is the expanded token sequence for that batch.
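+
+        Example (single item, assuming decoder_reduction_factor == 1, the default
+        0.08 s resolutions, and hypothetical token ids 7 and 8 for two one-token
+        words starting at frames 2 and 5):
+            word_tokens=[[7, 8]], word_lengths=[[1, 1]], start_time_tokens=[[2, 5]],
+            features_lens=[torch.tensor(9)], pad_id=0
+            -> [[0, 0, 7, 0, 0, 8, 0, 0, <text pad>]]
+        The final frame is reserved for the speech eos and is reset to the text
+        processor's pad id below.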
+ """ + def discretize_time(start_token, speech_resolution=0.08, timestamp_resolution=0.08): + """Convert the start token into a time index based on the resolution.""" + return int(start_token * timestamp_resolution / speech_resolution) + + if pad_id is None: + raise ValueError("pad_id must be provided.") + + batch_size = len(word_tokens) + max_length = max(features_lens).item() + + # Create the empty 2D tensor [batch, max_length] with pad_id as the default value + texts_expanded = torch.full((batch_size, max_length), fill_value=pad_id, dtype=torch.long) + + # Iterate over each batch + for batch_idx in range(batch_size): + # Remove the speech eos + batch_max_length = features_lens[batch_idx] - 1 + word_start_idx = 0 # Start index to keep track of the position within the concatenated word tokens + + for word_idx, word_length in enumerate(word_lengths[batch_idx]): + start_token = start_time_tokens[batch_idx][word_idx] + + # Convert the start time token into a time index based on frame rate + start_time_index = discretize_time(start_token, frame_rate) + + # Reduction of start time index due to stacking of frames + start_time_index = int(start_time_index / self.decoder_reduction_factor) + if batch_idx == 0: + logging.debug(f'start_time_index[0]: {start_time_index}') + + end_time_index = start_time_index + word_length + end_time_index = min(end_time_index, max_length) + + # Get the word tokens for the current word + word_token_ids = word_tokens[batch_idx][word_start_idx : word_start_idx + word_length] + + # Populate the tokens in the expanded tensor at the correct positions + for t_idx in range(start_time_index, end_time_index): + if t_idx - start_time_index < len(word_token_ids): # Ensure tokens are within bounds + token_id = word_token_ids[t_idx - start_time_index] # Get token for this time step + texts_expanded[batch_idx][t_idx] = token_id # Directly assign the token ID + + # Move to the next word in the concatenated word tokens + word_start_idx += word_length + + # Overwrite padding tokens + # import pdb; pdb.set_trace() + texts_expanded[batch_idx][batch_max_length:] = self.text_processor.pad_id + + return texts_expanded + def __getitem__duplex_(self, cuts) -> dict[str, torch.Tensor | list[str] | dict]: import re @@ -115,6 +216,7 @@ def __getitem__duplex_(self, cuts) -> dict[str, torch.Tensor | list[str] | dict] metadata = [] instructions, instruction_lengths = [], [] target_texts, target_text_lengths = [], [] + source_texts, source_text_lengths = [], [] remove_ids = [] start_time_tokens, word_lengths = [], [] num_turns = [] @@ -123,6 +225,7 @@ def __getitem__duplex_(self, cuts) -> dict[str, torch.Tensor | list[str] | dict] for id, cut in enumerate(cuts): num_turns.append(len(cut.supervisions)) metadata.append({'audio_filepath': cut.id + '.wav'}) + # logging.debug(f'reading: {cut.id}.wav') text_start_time.append([]) text_end_time.append([]) # treat multiturn data as multiple batch each with 2-turn conversation @@ -138,14 +241,31 @@ def __getitem__duplex_(self, cuts) -> dict[str, torch.Tensor | list[str] | dict] raise Exception("First speaker should be user") if supervisions[1].speaker == "agent": - pattern = r"<\|\d+\|>" - output_text = re.sub(pattern, "", supervisions[1].text) - output_text = re.sub(r'\s+', ' ', output_text).strip() - target_text = self.text_processor._process_example(context="", output=output_text) - # -1 to remove the eos token added by the text processor - target_text, target_text_length = torch.as_tensor(target_text["answer_ids"][:-1]), torch.as_tensor( - 
-                    len(target_text["answer_ids"]) - 1
-                )
+                use_word_alignment = getattr(cut, "s2s_duplex_align", False)
+                text = supervisions[1].text
+                pattern = r"<\|\d+\|>"
+                if not use_word_alignment:
+                    output_text = re.sub(pattern, "", supervisions[1].text)
+                    output_text = re.sub(r'\s+', ' ', output_text).strip()
+                    target_text = self.text_processor._process_example(context="", output=output_text)
+                    # -1 to remove the eos token added by the text processor
+                    target_text, target_text_length = torch.as_tensor(target_text["answer_ids"][:-1]), torch.as_tensor(
+                        len(target_text["answer_ids"]) - 1
+                    )
+                else:
+                    target_text, start_time_token, word_length = self._extract_text_and_time_tokens(text)
+                    target_text_length = len(target_text)
+                # Extract user text; defined for both formats because source_texts
+                # is appended to unconditionally below
+                output_text = re.sub(pattern, "", supervisions[0].text)
+                output_text = re.sub(r'\s+', ' ', output_text).strip()
+                source_text = self.text_processor._process_example(context="", output=output_text)
+                # -1 to remove the eos token added by the text processor
+                source_text, source_text_length = torch.as_tensor(source_text["answer_ids"][:-1]), torch.as_tensor(
+                    len(source_text["answer_ids"]) - 1
+                )
             else:
                 raise Exception("Second speaker should be agent")
 
@@ -153,8 +273,13 @@ def __getitem__duplex_(self, cuts) -> dict[str, torch.Tensor | list[str] | dict]
             instruction_lengths.append(instruction_length)
             target_texts.append(target_text)
             target_text_lengths.append(target_text_length)
+            source_texts.append(source_text)
+            source_text_lengths.append(source_text_length)
             text_start_time[-1].append(supervisions[1].start)
             text_end_time[-1].append(supervisions[1].start + supervisions[1].duration)
+            if use_word_alignment:
+                word_lengths.append(word_length)
+                start_time_tokens.append(start_time_token)
 
         answer_audios, answer_audio_lens = None, None
         assert self.load_answer_audio
@@ -256,8 +382,8 @@ def get_step_by_time(text_start_time):
 
         cnt = 0
         new_target_texts = []
+        new_source_texts = []
         for i in range(len(num_turns)):
-            each_target_texts = []
             total_steps = (
                 torch.ceil(
                     answer_audio_lens[i] / self.codec_model_downsampling_factor / self.decoder_reduction_factor
@@ -272,25 +398,71 @@ def get_step_by_time(text_start_time):
                     else self.text_processor.tokenizer.unk_id
                 ),
            )
+            cur_source_text = torch.full(
+                [total_steps],
+                (
+                    self.text_processor.tokenizer.pad_id
+                    if self.text_processor.tokenizer.pad_id >= 0
+                    else self.text_processor.tokenizer.unk_id
+                ),
+            )
 
             assert len(text_start_time[i]) == num_turns[i] // 2
             for j in range(num_turns[i] // 2):
                 text_start_step = get_step_by_time(text_start_time[i][j])
                 text_end_step = get_step_by_time(text_end_time[i][j]) + 1
                 cur_target_text[text_start_step] = self.text_processor.bos_id
-                text_len = min(text_end_step - text_start_step - 1, target_texts[cnt].shape[0])
-                cur_target_text[(text_start_step + 1) : (text_start_step + 1 + text_len)] = target_texts[cnt][
-                    :text_len
-                ]
+                cur_source_text[text_start_step] = self.text_processor.bos_id
+                # NOTE: `cut` here is the last cut from the enumeration above; this
+                # assumes every cut in the batch uses the same duplex text format,
+                # which is how __getitem__ dispatches to this method.
+                if getattr(cut, "s2s_duplex", False):
+                    # Note: text can be truncated
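+                    # (the text channel has text_end_step - text_start_step - 1 slots
+                    # between the bos and eos frames; tokens past that are dropped)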
+                    text_len = min(text_end_step - text_start_step - 1, target_texts[cnt].shape[0])
+                    cur_target_text[(text_start_step + 1) : (text_start_step + 1 + text_len)] = target_texts[cnt][
+                        :text_len
+                    ]
+                    src_text_len = min(text_end_step - text_start_step - 1, source_texts[cnt].shape[0])
+                    cur_source_text[(text_start_step + 1) : (text_start_step + 1 + src_text_len)] = source_texts[cnt][
+                        :src_text_len
+                    ]
+                elif getattr(cut, "s2s_duplex_align", False):
+                    text_len_plus_eos = torch.tensor(text_end_step - text_start_step)
+                    target_texts_expanded = self._expand_text_with_timestamps_and_word_lengths(
+                        [target_texts[cnt]],
+                        [word_lengths[cnt]],
+                        [start_time_tokens[cnt]],
+                        [text_len_plus_eos],
+                        self.codec_model_downsampling_factor / self.codec_sample_rate,
+                        pad_id=self.text_processor.unk_id,
+                    )
+                    logging.debug(f'start_time_token: {start_time_tokens[cnt]}')
+                    logging.debug(f'target_tokens: {target_texts[cnt]}')
+                    logging.debug(f'word_length: {word_lengths[cnt]}')
+                    logging.debug(f'target_texts_expanded: {target_texts_expanded[0]}')
+                    cur_target_text[(text_start_step + 1) : (text_start_step + 1 + text_len_plus_eos)] = (
+                        target_texts_expanded[0]
+                    )
+                else:
+                    raise Exception("Undefined assistant channel text format.")
+
+                cur_target_text[text_end_step] = self.text_processor.eos_id
+                cur_source_text[text_end_step] = self.text_processor.eos_id
                 cnt += 1
             new_target_texts.append(cur_target_text)
+            new_source_texts.append(cur_source_text)
 
         target_texts_merge, target_text_lengths = collate_and_pad(new_target_texts)
+        source_texts_merge, source_text_lengths = collate_and_pad(new_source_texts)
         assert cnt == len(target_texts)
         assert target_texts_merge.shape[0] == len(num_turns)
+        assert cnt == len(source_texts)
+        assert source_texts_merge.shape[0] == len(num_turns)
 
         # Merge batch
         # note: the codec id in labels and contexts and others do not consider the offset e.g. speech_eos is 1002
         # the offset is all considered by SumVocabParallelEmbedding
         return_batch = {
             "sample_ids": list(cuts.ids),
             "audio_signal": audio,
@@ -300,6 +472,7 @@ def get_step_by_time(text_start_time):
             "instructions": None,
             "tokens": target_texts_merge,  # used in _reconfigure_and_process_inference_batch
             "target_texts_merge": target_texts_merge,  # used in prepare_llm_input
+            "source_texts_merge": source_texts_merge,  # used in prepare_llm_input
             "contexts": target_texts_merge[:, :1],  # used in inference
             "context_lengths": torch.ones_like(target_text_lengths),
             "target_texts": target_texts_merge,
@@ -315,41 +488,9 @@
     def __getitem__(self, cuts) -> dict[str, torch.Tensor | list[str] | dict]:
         import re
 
-        if getattr(cuts[0], "s2s_duplex"):
+        if getattr(cuts[0], "s2s_duplex", False) or getattr(cuts[0], "s2s_duplex_align", False):
             return self.__getitem__duplex_(cuts)
 
-        def extract_text_and_time_tokens(input_sequence):
-            # Regular expression to match time tokens (e.g., <|x|> where x is an integer)
-            time_token_pattern = r"<\|(\d+)\|>"
-            # Find all time tokens
-            time_tokens = re.findall(time_token_pattern, input_sequence)
-            # Only keep the first token of every pair (i.e., start time tokens)
-            start_time_token = [int(time_tokens[i]) for i in range(0, len(time_tokens), 2)]
-            # Remove all time tokens to isolate words
-            words = re.sub(time_token_pattern, '', input_sequence).split()
-            # Process each word, tokenize it, and calculate token lengths
-            tokenized_words = []
-            word_length = []
-            for idx, word in enumerate(words):
-                # Tokenize the word using the provided text processor
-                if id == 0:
-                    logging.debug(f'word: {word}')
-                tokenized_word = self.text_processor._process_example(context="", output=word)
-                # Remove the EOS token (assuming the EOS token is at the end of "answer_ids")
-                token_ids = tokenized_word["answer_ids"][:-1]  # Remove EOS token
-                if idx != 0:  # If not the first word, remove the first token
-                    token_ids = token_ids[1:]
-                if id == 0:
-                    logging.debug(f'token_ids: {token_ids}')
-                token_length = len(token_ids)  # Calculate the length
-                tokenized_words.extend(token_ids)
-                word_length.append(token_length)
-            return (
-                torch.as_tensor(tokenized_words),
-                torch.as_tensor(start_time_token),
-                torch.as_tensor(word_length),
-            )
-
         cuts = cuts.sort_by_duration()
 
         logging.debug(f"Len: {len(cuts)}")
@@ -388,7 +529,7 @@ def extract_text_and_time_tokens(input_sequence):
                     target_text["answer_ids"][:-1]
                 ), torch.as_tensor(len(target_text["answer_ids"]) - 1)
             else:
-                target_text, start_time_token, word_length = extract_text_and_time_tokens(text)
+                target_text, start_time_token, word_length = self._extract_text_and_time_tokens(text)
                 target_text_length = len(target_text)
         else:
             raise Exception("Second speaker should be agent")
@@ -559,75 +700,6 @@ def _convert_text_to_3d_tensor(texts, include_eos=True, tokens_to_generate=0):
            texts_expanded = texts_expanded[:, :-1]
         return texts, text_lengths, texts_expanded
 
-        def discretize_time(start_token, speech_resolution=0.08, timestamp_resolution=0.08):
-            """Convert the start token into a time index based on the resolution."""
-            return int(start_token * timestamp_resolution / speech_resolution)
-
-        def _expand_text_with_timestamps_and_word_lengths(
-            word_tokens, word_lengths, start_time_tokens, features_lens, frame_rate=0.08, pad_id=None
-        ):
-            """
-            Expand word tokens according to start time tokens and word lengths for a batch of sequences.
-
-            Args:
-            - word_tokens: List of lists of token sequences (each inner list is a word's token IDs), shape [batch][time].
-            - word_lengths: List of lists of word lengths, shape [batch][time].
-            - start_time_tokens: List of lists of start times, shape [batch][time].
-            - max_length: Maximum length in the time dimension (number of frames).
-            - frame_rate: Frame rate resolution.
-            - pad_id: Padding ID to use for empty positions in the tensor.
-
-            Returns:
-            - 2D tensor [batch, max_length] where each row is the expanded token sequence for that batch.
-            """
-            if pad_id is None:
-                raise ValueError("pad_id must be provided.")
-
-            batch_size = len(word_tokens)
-            max_length = max(features_lens).item()
-
-            # Create the empty 2D tensor [batch, max_length] with pad_id as the default value
-            texts_expanded = torch.full((batch_size, max_length), fill_value=pad_id, dtype=torch.long)
-
-            # Iterate over each batch
-            for batch_idx in range(batch_size):
-                # Remove the speech eos
-                batch_max_length = features_lens[batch_idx] - 1
-                word_start_idx = 0  # Start index to keep track of the position within the concatenated word tokens
-
-                # Iterate over the words in the current batch
-                for word_idx, word_length in enumerate(word_lengths[batch_idx]):
-                    start_token = start_time_tokens[batch_idx][word_idx]
-
-                    # Convert the start time token into a time index based on frame rate
-                    start_time_index = discretize_time(start_token, frame_rate)
-
-                    # Reduction of start time index due to stacking of frames
-                    start_time_index = int(start_time_index / self.decoder_reduction_factor)
-                    if batch_idx == 0:
-                        logging.debug(f'start_time_index[0]: {start_time_index}')
-
-                    # Calculate the end time index based on word length
-                    end_time_index = start_time_index + word_length
-                    end_time_index = min(end_time_index, max_length)  # Ensure it doesn't exceed max length
-
-                    # Get the word tokens for the current word
-                    word_token_ids = word_tokens[batch_idx][word_start_idx : word_start_idx + word_length]
-
-                    # Populate the tokens in the expanded tensor at the correct positions
-                    for t_idx in range(start_time_index, end_time_index):
-                        if t_idx - start_time_index < len(word_token_ids):  # Ensure tokens are within bounds
-                            token_id = word_token_ids[t_idx - start_time_index]  # Get token for this time step
-                            texts_expanded[batch_idx][t_idx] = token_id  # Directly assign the token ID
-
-                    # Move to the next word in the concatenated word tokens
-                    word_start_idx += word_length
-
-                # Overwrite padding tokens
-                texts_expanded[batch_idx][batch_max_length:] = text_pad_id
-
-            return texts_expanded
-
         batch_size = audio.shape[0]
         # TODO: can remove the following except features_lens
         target_codec = get_3d_empty_tensor(batch_size, max(features_lens).item() + 1, text_pad_id, self.speech_pad_id)
@@ -684,7 +756,7 @@ def _expand_text_with_timestamps_and_word_lengths(
         bos_tensor[:, :, 0] = self.text_processor.bos_id
         # [batch, max_feat_len]
         # the only thing needed is features_lens which can be estimated from target_audio length
-        target_texts_expanded = _expand_text_with_timestamps_and_word_lengths(
+        target_texts_expanded = self._expand_text_with_timestamps_and_word_lengths(
             unpadded_target_texts,
             word_lengths,
             start_time_tokens,
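
A minimal standalone sketch of the aligned-text scheme this patch adds, using a toy one-token-per-word tokenizer in place of TextProcessing._process_example and assuming decoder_reduction_factor == 1 with equal 0.08 s timestamp and frame resolutions (under which the patch's discretize_time() is the identity); only the <|x|> parsing and the frame-level expansion mirror the code above:

import re

import torch

DECODER_REDUCTION_FACTOR = 1  # assumed; the patch divides frame indices by this factor


def extract_text_and_time_tokens(s, base_id=100):
    """Parse '<|2|>hello<|4|> <|5|>world<|7|>' into (token_ids, start_frames, word_lengths)."""
    time_token_pattern = r"<\|(\d+)\|>"
    times = re.findall(time_token_pattern, s)
    # Keep only the first time token of each <start><end> pair.
    start_frames = [int(times[i]) for i in range(0, len(times), 2)]
    words = re.sub(time_token_pattern, '', s).split()
    # Toy tokenizer: one hypothetical id per word. A real tokenizer emits several
    # sub-word ids per word, which is why word_lengths is tracked separately.
    tokens, word_lengths = [], []
    for i, _word in enumerate(words):
        ids = [base_id + i]
        tokens.extend(ids)
        word_lengths.append(len(ids))
    return torch.tensor(tokens), torch.tensor(start_frames), torch.tensor(word_lengths)


def expand_to_frames(tokens, word_lengths, start_frames, num_frames, pad_id=0):
    """Write each word's tokens into the frame grid starting at its start frame."""
    out = torch.full((num_frames,), pad_id, dtype=torch.long)
    cursor = 0
    for start, length in zip(start_frames.tolist(), word_lengths.tolist()):
        start = start // DECODER_REDUCTION_FACTOR  # frame stacking, as in the patch
        end = min(start + length, num_frames)
        out[start:end] = tokens[cursor : cursor + (end - start)]
        cursor += length
    return out


tokens, starts, lengths = extract_text_and_time_tokens("<|2|>hello<|4|> <|5|>world<|7|>")
print(expand_to_frames(tokens, lengths, starts, num_frames=10))
# tensor([  0,   0, 100,   0,   0, 101,   0,   0,   0,   0])

Each frame of the assistant channel thus carries either a text token aligned to the word being spoken at that time or the pad/unk filler, which is the layout prepare_llm_input consumes via target_texts_merge.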