Commit

add inference support for align_s2s, s2s, and direct_s2s. removed comments
kevinhu-nv committed Nov 21, 2024
1 parent e7fffd7 commit 55cb4ce
Showing 2 changed files with 26 additions and 38 deletions.
22 changes: 0 additions & 22 deletions nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py
@@ -125,15 +125,11 @@ def extract_text_and_time_tokens(input_sequence):
     word_length = []
     for idx, word in enumerate(words):
         # Tokenize the word using the provided text processor
-        if id == 0:
-            logging.debug(f'word: {word}')
         tokenized_word = self.text_processor._process_example(context="", output=word)
         # Remove the EOS token (assuming the EOS token is at the end of "answer_ids")
         token_ids = tokenized_word["answer_ids"][:-1]  # Remove EOS token
         if idx != 0:  # If not the first word, remove the first token
             token_ids = token_ids[1:]
-        if id == 0:
-            logging.debug(f'token_ids: {token_ids}')
         token_length = len(token_ids)  # Calculate the length
         tokenized_words.extend(token_ids)
         word_length.append(token_length)
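Note on the loop above: each word is tokenized in isolation, and every word after the first drops its leading token so that concatenating the per-word pieces matches the tokenization of the full sentence, while word_length records how many tokens each word contributes. A standalone sketch of that bookkeeping, assuming a SentencePiece-style tokenizer that prefixes each word with a boundary token (`tokenize` below is a hypothetical stand-in for the dataset's text processor):

    def tokenize_words(words, tokenize):
        # `tokenize` is assumed to return token ids ending in an EOS id.
        tokenized_words, word_length = [], []
        for idx, word in enumerate(words):
            token_ids = tokenize(word)[:-1]  # drop the EOS id
            if idx != 0:
                token_ids = token_ids[1:]  # drop the duplicated boundary token
            tokenized_words.extend(token_ids)
            word_length.append(len(token_ids))
        return tokenized_words, word_length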
@@ -184,8 +180,6 @@ def _expand_text_with_timestamps_and_word_lengths(

     # Reduction of start time index due to stacking of frames
     start_time_index = int(start_time_index / self.decoder_reduction_factor)
-    if batch_idx == 0:
-        logging.debug(f'start_time_index[0]: {start_time_index}')
 
     # Calculate the end time index based on word length
     end_time_index = start_time_index + word_length
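Note on the arithmetic above: a word's start time is first mapped to a codec-frame index (elsewhere in this diff the frame duration is derived from codec_model_downsampling_factor / sample_rate), then divided by the decoder reduction factor because stacked frames shorten the decoder sequence. A worked example with made-up values:

    codec_frame_rate = 80         # made up: codec frames per second
    decoder_reduction_factor = 2  # made up: frames stacked per decoder step
    start_time = 1.5              # word onset in seconds
    start_time_index = int(start_time * codec_frame_rate)                # 120
    start_time_index = int(start_time_index / decoder_reduction_factor)  # 60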
@@ -210,8 +204,6 @@

 cuts = cuts.sort_by_duration()
 
-logging.debug(f"Len: {len(cuts)}")
-
 metadata = []
 instructions, instruction_lengths = [], []
 target_texts, target_text_lengths = [], []
@@ -311,8 +303,6 @@ def collate_and_pad(inputs):
         tokens[i, : token_lengths[i], :] = inputs[i]
     return tokens, torch.LongTensor(token_lengths)
 
-# import pdb; pdb.set_trace()
-
 target_codec = None
 answer_audios, answer_audio_lens = None, None
 assert self.load_answer_audio
@@ -390,8 +380,6 @@ def _convert_text_to_3d_tensor(texts, include_eos=True, tokens_to_generate=0):
     texts_expanded = texts_expanded[:, :-1]
     return texts, text_lengths, texts_expanded
 
-# import pdb; pdb.set_trace()
-
 unpadded_target_texts = target_texts
 target_texts, target_text_lengths, target_texts_expanded = _convert_text_to_3d_tensor(target_texts)
 instructions, instruction_lengths, instructions_expanded_no_eos = _convert_text_to_3d_tensor(
@@ -405,7 +393,6 @@ def _convert_text_to_3d_tensor(texts, include_eos=True, tokens_to_generate=0):

 # TODO: remove the following stanza
 if getattr(cut, "s2s", False):
-    logging.debug(f'target_texts_expanded: {target_texts_expanded[0,:]}')
     # Add 1 for eos token
     token_list = [
         torch.concat([tt[: ttl + 1], tc[: tcl + 1]], 0)
@@ -446,19 +433,11 @@ def _convert_text_to_3d_tensor(texts, include_eos=True, tokens_to_generate=0):
         self.codec_model_downsampling_factor / self.sample_rate,
         pad_id=text_unk_id,
     )
-    # import pdb; pdb.set_trace()
-    logging.debug(f'start_time_token: {start_time_tokens[0]}')
-    logging.debug(f'word_length: {word_lengths[0]}')
-    logging.debug(f'target_tokens: {unpadded_target_texts[0]}')
-    logging.debug(f'target_texts_expanded: {target_texts_expanded[0,:]}')
     # [batch, max_feat_len, 1+V], where V = #codebooks * reduction_factor
     target_codec[:, :, 0] = target_texts_expanded
     token_list = torch.concat([bos_tensor, target_codec], 1)
     features_lens += 1
 
-    # import pdb; pdb.set_trace()
-
-    logging.debug(f'token_list[0].shape: {token_list[0].shape}')
     if not self.t5_style:
         token_list = [
             torch.concat([it[:itl], tt], 0)
@@ -473,7 +452,6 @@ def _convert_text_to_3d_tensor(texts, include_eos=True, tokens_to_generate=0):
     loss_mask[:, :itl, :] = False
     full_lengths = features_lens + 1 + instruction_lengths
     target_text_lengths = -1 * torch.ones_like(target_text_lengths)
-    # import pdb; pdb.set_trace()
 elif getattr(cut, "direct_s2s", False):
     # Add 1 for eos token
     # tt[0] is the bos token
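Note on the align_s2s stanza above: target_codec has shape [batch, max_feat_len, 1+V], where channel 0 carries text tokens expanded to align with audio frames and the remaining V channels carry codec codes (V = #codebooks * reduction_factor). A minimal sketch of that layout, with made-up sizes:

    import torch

    # Made-up sizes for illustration.
    batch, max_feat_len, n_codebooks, reduction_factor = 2, 50, 8, 1
    V = n_codebooks * reduction_factor

    texts_expanded = torch.randint(0, 100, (batch, max_feat_len))   # frame-aligned text ids
    codec_codes = torch.randint(0, 1024, (batch, max_feat_len, V))  # audio codec ids

    target_codec = torch.zeros(batch, max_feat_len, 1 + V, dtype=torch.long)
    target_codec[:, :, 0] = texts_expanded  # channel 0: text aligned to audio frames
    target_codec[:, :, 1:] = codec_codes    # channels 1..V: codec codebooks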
42 changes: 26 additions & 16 deletions nemo/collections/multimodal/speech_llm/models/modular_s2s_models.py
@@ -399,27 +399,38 @@ def inference_step(self, dataloader_iter, mode):
         return outputs
 
     def parse_decoder_outputs(
-        self, input_decoder_output, text_separator, context_length, speech_pad_id=1001, speech_eos_id=1004
+        self, input_decoder_output, text_separator, context_length, speech_pad_id=1001, speech_eos_id=1004, text_pad_id=0,
     ):
         # remove text context
         max_len = input_decoder_output.shape[0]
         decoder_output = input_decoder_output[-1:].tile([max_len, 1])
         decoder_output[: max_len - context_length] = input_decoder_output[context_length:]
 
-        # Do not split because text and speech are now aligned
-        # Split text and speech part based on the position of the first separator token
-        # sep_pos = (decoder_output[:, 0] == text_separator).long()
-        # if torch.any(sep_pos):
-        #     first_sep_pos = torch.argmax(sep_pos)
-        #     text_tokens = decoder_output[:first_sep_pos, 0]
-        #     speech_tokens = decoder_output[first_sep_pos + 1 :, 1:]
-        # else:
-        #     text_tokens = decoder_output[:, 0]
-        #     speech_tokens = decoder_output[:, 1:]
-        text_tokens = decoder_output[:, 0]
-        speech_tokens = decoder_output[:, 1:]
-
-        # import pdb; pdb.set_trace()
+        text_channel = decoder_output[:, 0]
+
+        # ad hoc: assume text_pad_id appears before the text EOS in the align_s2s case
+        sep_indices = (text_channel == text_separator).nonzero(as_tuple=True)[0]
+        index_sep = sep_indices[0].item() if sep_indices.numel() > 0 else None
+        pad_indices = (text_channel == text_pad_id).nonzero(as_tuple=True)[0]
+        index_pad = pad_indices[0].item() if pad_indices.numel() > 0 else None
+        is_align_s2s = index_pad is not None and (index_sep is None or index_pad < index_sep)
+        is_align_s2s = True
+        if is_align_s2s:
+            text_tokens_with_pads = decoder_output[:, 0]
+            text_tokens = text_tokens_with_pads[text_tokens_with_pads != text_pad_id]
+            speech_tokens = decoder_output[:, 1:]
+        else:
+            # s2s predicts [text, text_separator] for the first channel
+            sep_pos = (text_channel == text_separator).long()
+            is_s2s = torch.any(sep_pos)
+            if is_s2s:
+                first_sep_pos = torch.argmax(sep_pos)
+                text_tokens = decoder_output[:first_sep_pos, 0]
+                speech_tokens = decoder_output[first_sep_pos + 1 :, 1:]
+            else:
+                # direct_s2s
+                text_tokens = decoder_output[:, 0]
+                speech_tokens = decoder_output[:, 1:]
 
         # Get speech token ids
         n_speech_codebooks = self.model.n_proj_heads - 1
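Note on the new branching above: the three inference modes are distinguished by inspecting the text channel. align_s2s keeps text frame-aligned with speech, so text pad ids appear before any separator; s2s emits [text, separator] in channel 0 with speech after the separator; direct_s2s has neither marker. The committed code then hardcodes is_align_s2s = True, so the detection result is currently overridden. A standalone sketch of the detection rule, with made-up separator and pad ids:

    import torch

    def detect_mode(text_channel, text_separator=1000, text_pad_id=0):
        # Classify the decoder's text channel as align_s2s, s2s, or direct_s2s.
        sep = (text_channel == text_separator).nonzero(as_tuple=True)[0]
        pad = (text_channel == text_pad_id).nonzero(as_tuple=True)[0]
        first_sep = sep[0].item() if sep.numel() > 0 else None
        first_pad = pad[0].item() if pad.numel() > 0 else None
        if first_pad is not None and (first_sep is None or first_pad < first_sep):
            return "align_s2s"  # pads interleave with text before any separator
        return "s2s" if first_sep is not None else "direct_s2s"

    print(detect_mode(torch.tensor([5, 0, 7, 0, 1000])))  # align_s2s
    print(detect_mode(torch.tensor([5, 7, 9, 1000])))     # s2s
    print(detect_mode(torch.tensor([5, 7, 9])))           # direct_s2s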
@@ -824,7 +835,6 @@ def write_predictions_to_file(self, outputs, output_file_path_prefix, output_dir
     def de_concat_multiproj_logits(self, logits):
         logits_list = []
         prev = 0
-        # import pdb; pdb.set_trace()
         for i in self.model.proj_head_dims:
             logits_list.append(logits[:, prev : prev + i])
             prev += i
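For reference, de_concat_multiproj_logits simply slices the concatenated projection-head logits back into per-head tensors by width. A self-contained example with made-up head dimensions:

    import torch

    proj_head_dims = [512, 1025, 1025]  # made up: one text head, two codec heads
    logits = torch.randn(4, sum(proj_head_dims))

    logits_list, prev = [], 0
    for dim in proj_head_dims:
        logits_list.append(logits[:, prev : prev + dim])
        prev += dim

    print([tuple(t.shape) for t in logits_list])
    # [(4, 512), (4, 1025), (4, 1025)]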
