From 713c931d54e775fe8562f91f3e3ec39c0df06e91 Mon Sep 17 00:00:00 2001 From: zhehuaichen Date: Tue, 12 Nov 2024 10:37:37 -0800 Subject: [PATCH] copy over the changes from Kevin about align_s2s documented in https://docs.google.com/document/d/1LRMwpUt8TH96oRVi0zz7OOTbejrMk9zs4VH1SEX59ZM/edit?tab=t.0 and /lustre/fsw/portfolios/llmservice/users/kevinhu/works/mod_speech_llm/code/NeMo_s2s_align_debug/ Signed-off-by: zhehuaichen --- .../speech_llm/conf/s2s/pt_salm_1a.yaml | 52 ++- .../speech_llm/conf/s2s/pt_salm_1a_s2s.yaml | 431 ++++++++++++++++++ .../conf/s2s/pt_salm_1a_s2s_direct.yaml | 431 ++++++++++++++++++ .../speech_llm/conf/s2s/pt_salm_1b.yaml | 15 +- .../speech_llm/modular_audio_gpt_train.py | 4 + .../speech_llm/data/lhotse_dataset.py | 196 +++++++- .../speech_llm/models/modular_models.py | 9 + .../speech_llm/models/modular_s2s_models.py | 25 +- .../common/audio_text_generation_strategy.py | 1 + .../common/audio_text_generation_utils.py | 4 + .../speech_llm/parts/utils/data_utils.py | 4 + 11 files changed, 1128 insertions(+), 44 deletions(-) create mode 100644 examples/multimodal/speech_llm/conf/s2s/pt_salm_1a_s2s.yaml create mode 100644 examples/multimodal/speech_llm/conf/s2s/pt_salm_1a_s2s_direct.yaml diff --git a/examples/multimodal/speech_llm/conf/s2s/pt_salm_1a.yaml b/examples/multimodal/speech_llm/conf/s2s/pt_salm_1a.yaml index aa425d1b9..ca375d9a5 100644 --- a/examples/multimodal/speech_llm/conf/s2s/pt_salm_1a.yaml +++ b/examples/multimodal/speech_llm/conf/s2s/pt_salm_1a.yaml @@ -238,35 +238,40 @@ model: train_ds: input_cfg: - type: lhotse_shar - shar_path: /workspace/data/s2s_shars/es/ + # shar_path: /workspace/data/s2s_shars/es/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/es/timestamp/ weight: 1.0 tags: lang: en - s2s: True + s2s_align: True - type: lhotse_shar - shar_path: /workspace/data/s2s_shars/msmacro/ + # shar_path: /workspace/data/s2s_shars/msmacro/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro/timestamp/ weight: 0.2 tags: lang: en - s2s: True + s2s_align: True - type: lhotse_shar - shar_path: /workspace/data/s2s_shars/alpaca/ + # shar_path: /workspace/data/s2s_shars/alpaca/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/alpaca/timestamp/ weight: 0.1 tags: lang: en - s2s: True + s2s_align: True - type: lhotse_shar - shar_path: /workspace/data/s2s_shars/squadv2/ + # shar_path: /workspace/data/s2s_shars/squadv2/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/squadv2/timestamp/ weight: 0.05 tags: lang: en - s2s: True + s2s_align: True - type: lhotse_shar - shar_path: /workspace/data/s2s_shars/msmacro_speech_instruct/ + # shar_path: /workspace/data/s2s_shars/msmacro_speech_instruct/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro_speech_instruct/timestamp/ weight: 0.2 tags: lang: en - s2s: True + s2s_align: True global_batch_size: ${model.global_batch_size} micro_batch_size: ${model.micro_batch_size} shuffle: True @@ -326,38 +331,43 @@ model: weight: 1.0 tags: lang: en - s2s: True + s2s_align: True input_cfg: - type: lhotse_shar - shar_path: /workspace/data/s2s_shars/es_validation/ # ST + # shar_path: /workspace/data/s2s_shars/es_validation/ # ST + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/es/timestamp/ weight: 1.0 tags: lang: en - s2s: True + s2s_align: True - type: lhotse_shar - shar_path: /workspace/data/s2s_shars/msmacro/ # text SQA + # shar_path: /workspace/data/s2s_shars/msmacro/ # text SQA + shar_path: 
/lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro/timestamp/ weight: 1.0 tags: lang: en - s2s: True + s2s_align: True - type: lhotse_shar - shar_path: /workspace/data/s2s_shars/alpaca/ # text SQA + # shar_path: /workspace/data/s2s_shars/alpaca/ # text SQA + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/alpaca/timestamp/ weight: 1.0 tags: lang: en - s2s: True + s2s_align: True - type: lhotse_shar - shar_path: /workspace/data/s2s_shars/squadv2/ # speech SQA + # shar_path: /workspace/data/s2s_shars/squadv2/ # speech SQA + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/squadv2/timestamp/ weight: 0.1 tags: lang: en - s2s: True + s2s_align: True - type: lhotse_shar - shar_path: /workspace/data/s2s_shars/msmacro_speech_instruct/ # speech SQA + # shar_path: /workspace/data/s2s_shars/msmacro_speech_instruct/ # speech SQA + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro_speech_instruct/timestamp/ weight: 0.9 tags: lang: en - s2s: True + s2s_align: True global_batch_size: ${model.global_batch_size} micro_batch_size: ${model.micro_batch_size} diff --git a/examples/multimodal/speech_llm/conf/s2s/pt_salm_1a_s2s.yaml b/examples/multimodal/speech_llm/conf/s2s/pt_salm_1a_s2s.yaml new file mode 100644 index 000000000..2639206e0 --- /dev/null +++ b/examples/multimodal/speech_llm/conf/s2s/pt_salm_1a_s2s.yaml @@ -0,0 +1,431 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: megatron_audio_gpt_s2s_lhotse + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 9999 + max_steps: 1000000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + limit_train_batches : 1000000 + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: 1000 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 
0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + accumulate_grad_batches: 1 + +model_target: nemo.collections.multimodal.speech_llm.models.modular_s2s_models.S2sModularAudioGPTModel + +exp_manager: + # explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metrics[0].name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + +model: + # TBD: + salm_model_path: ??? + speech_pad_id: 1001 + speech_unk_id: 1002 + speech_bos_id: 1003 + speech_eos_id: 1004 + # proj_head_dims[0] should be set to LLM vocab size (32k below) + # proj_head_dims[1:] should be set to nemo codec vocab size (1024 below) + proj_head_dims: [32000, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024] + s2s_vocab_size: 72960 # This should be equal to the proj_head_dims[0] + decoder_reduction_factor * sum(proj_head_dims[1:]) + proj_head_loss_weights: [1, 1, 1, 1, 1, 1, 1, 1, 1] + decoder_reduction_factor: 5 + + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + pretrained_audio_model: stt_en_fastconformer_transducer_large + freeze_llm: True + freeze_audio_encoder: False + freeze_modality_adapter: False + load_audio_encoder: True + + ## Legacy batch_size configuration + # When used with lhotse, the batch composition is decided by dataloader configs + # and batch size here is only used for deciding gradient accumulation. + # gradient accumulation = global_batch_size / micro_batch_size / data_parallel_size + # where data_parallel_size = num_nodes * num_gpus / TP_size + global_batch_size: 128 + micro_batch_size: 4 + restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: False + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. 
+ sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + answer_only_loss: True + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + # use_am_tokenizer: True + # for MegatronNMTMultiProjModel and datasets + + peft: + peft_scheme: "lora" # can be either lora, adapter, ia3 or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['attention_qkv','attention_dense','mlp_fc1','mlp_fc2'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) + adapter_dim: 32 + alpha: ${model.peft.lora_tuning.adapter_dim} + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + perception: + target: nemo.collections.multimodal.speech_llm.modules.perception_modules.AudioPerceptionModule + use_multi_layer_feat: false + multi_layer_feat: + layer_idx_list: [0,16] # layer indices to extract features from + aggregator: + mode: "cat" # ways to combine features from different layers, choices=['cat','sum','mean', 'max', 'min'], default to concat ('cat') + pooling: "avg" # ways to pool features if they have different temporal lengths and align_mode=min, choices=['mean', 'max', 'min'] + align_mode: "min" # if features have different temporal lengths, set `min` to pool to the shortest length or `max` to repeat to the longest. 
+ + modality_adapter: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 1024 + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 2 + d_model: 512 + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Reduction parameters: Can be used to add another subsampling layer at a given position. + # Having a 2x reduction will speedup the training and inference speech while keeping similar WER. + # Adding it at the end will give the best WER while adding it at the beginning will give the best speedup. + reduction: null # pooling, striding, or null + reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + # the following are read from the pretrained AM: + # output_dim: null + # encoder: null + # preprocessor: null + + codec_model_path: ??? 
+ asr_model_path: "stt_en_fastconformer_transducer_large" + + data: + end_string: "" + train_ds: + input_cfg: + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/es/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/es/timestamp/ + weight: 1.0 + tags: + lang: en + s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/msmacro/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro/timestamp/ + weight: 0.2 + tags: + lang: en + s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/alpaca/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/alpaca/timestamp/ + weight: 0.1 + tags: + lang: en + s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/squadv2/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/squadv2/timestamp/ + weight: 0.05 + tags: + lang: en + s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/msmacro_speech_instruct/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro_speech_instruct/timestamp/ + weight: 0.2 + tags: + lang: en + s2s: True + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + pin_memory: True + max_seq_length: 512 + min_seq_length: 1 + drop_last: True + # Notably, the data weights are controlled by either bucketing_weights + # or concat_sampling_probabilities depending on the dataset type (tar and + # non-tar). + # See audio_text_qa_dataset.py for details. + concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + context_key: 'context' + ali_score_key: 'ali_score' + answer_key: 'answer' + add_eos: True + # add_eos: False + end_string: ${model.data.end_string} + add_sep: False + add_bos: False + separate_prompt_and_response_with_newline: False + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + # prompt_template: "{context}\n{answer}" # fstring to use for assistant prompt. 
Example: "Q: {input}\nA: {output}" + prompt_template: "[INST]\n<>\nPlease answer the following based on the previous speech feature.\n<>\n\n{context}[/INST] {answer}" + speech_pad_id: ${model.speech_pad_id} + speech_unk_id: ${model.speech_unk_id} + speech_bos_id: ${model.speech_bos_id} + speech_eos_id: ${model.speech_eos_id} + filter_by_source_target_text_ratio: False # If the length of the source and target texts differ too much, the instance will be discarded + source_target_text_ratio_limit: 4.0 + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + max_duration: 24 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + # tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "fully_randomized" + bucketing_batch_size: null + use_lhotse: True + text_field : "text" + batch_duration : 80 # 0 + quadratic_duration : 30 + num_buckets : 30 + buffer_size : 10000 + shuffle_buffer_size : 10000 + duration_bins: null + + validation_ds: + names: s2st + input_cfg: + - type: group + weight: 1.0 + tags: + lang: en + s2s: True + input_cfg: + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/es_validation/ # ST + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/es/timestamp/ + weight: 1.0 + tags: + lang: en + s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/msmacro/ # text SQA + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro/timestamp/ + weight: 1.0 + tags: + lang: en + s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/alpaca/ # text SQA + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/alpaca/timestamp/ + weight: 1.0 + tags: + lang: en + s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/squadv2/ # speech SQA + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/squadv2/timestamp/ + weight: 0.1 + tags: + lang: en + s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/msmacro_speech_instruct/ # speech SQA + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro_speech_instruct/timestamp/ + weight: 0.9 + tags: + lang: en + s2s: True + + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + pin_memory: True + max_seq_length: 512 + min_seq_length: 1 + drop_last: False + context_key: ${model.data.train_ds.context_key} + ali_score_key: ${model.data.train_ds.ali_score_key} + answer_key: ${model.data.train_ds.answer_key} + add_eos: ${model.data.train_ds.add_eos} + end_string: ${model.data.end_string} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: True + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. 
Example: "Q: {input}\nA: {output}" + tokens_to_generate: 256 + speech_pad_id: ${model.speech_pad_id} + speech_unk_id: ${model.speech_unk_id} + speech_bos_id: ${model.speech_bos_id} + speech_eos_id: ${model.speech_eos_id} + filter_by_source_target_text_ratio: False # If the length of the source and target texts differ too much, the instance will be discarded + source_target_text_ratio_limit: 4.0 + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + + log_every_n_steps: 10 + metrics: + - name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + - name: "bleu" + average: null + num_classes: null + - name: "asr-bleu" + average: null + num_classes: null + - name: "asr-wer" + average: null + num_classes: null + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 50 + min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 + constant_steps: 0 # Constant steps should also be 0 when min_lr=0 + monitor: val_loss + reduce_on_plateau: false diff --git a/examples/multimodal/speech_llm/conf/s2s/pt_salm_1a_s2s_direct.yaml b/examples/multimodal/speech_llm/conf/s2s/pt_salm_1a_s2s_direct.yaml new file mode 100644 index 000000000..dd5ef6cba --- /dev/null +++ b/examples/multimodal/speech_llm/conf/s2s/pt_salm_1a_s2s_direct.yaml @@ -0,0 +1,431 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: megatron_audio_gpt_s2s_lhotse + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 9999 + max_steps: 1000000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + limit_train_batches : 1000000 + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: 1000 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 
0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + accumulate_grad_batches: 1 + +model_target: nemo.collections.multimodal.speech_llm.models.modular_s2s_models.S2sModularAudioGPTModel + +exp_manager: + # explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metrics[0].name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + +model: + # TBD: + salm_model_path: ??? + speech_pad_id: 1001 + speech_unk_id: 1002 + speech_bos_id: 1003 + speech_eos_id: 1004 + # proj_head_dims[0] should be set to LLM vocab size (32k below) + # proj_head_dims[1:] should be set to nemo codec vocab size (1024 below) + proj_head_dims: [32000, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024] + s2s_vocab_size: 72960 # This should be equal to the proj_head_dims[0] + decoder_reduction_factor * sum(proj_head_dims[1:]) + proj_head_loss_weights: [1, 1, 1, 1, 1, 1, 1, 1, 1] + decoder_reduction_factor: 5 + + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + pretrained_audio_model: stt_en_fastconformer_transducer_large + freeze_llm: True + freeze_audio_encoder: False + freeze_modality_adapter: False + load_audio_encoder: True + + ## Legacy batch_size configuration + # When used with lhotse, the batch composition is decided by dataloader configs + # and batch size here is only used for deciding gradient accumulation. + # gradient accumulation = global_batch_size / micro_batch_size / data_parallel_size + # where data_parallel_size = num_nodes * num_gpus / TP_size + global_batch_size: 128 + micro_batch_size: 4 + restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: False + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. 
+ sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + answer_only_loss: True + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + # use_am_tokenizer: True + # for MegatronNMTMultiProjModel and datasets + + peft: + peft_scheme: "lora" # can be either lora, adapter, ia3 or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['attention_qkv','attention_dense','mlp_fc1','mlp_fc2'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) + adapter_dim: 32 + alpha: ${model.peft.lora_tuning.adapter_dim} + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + perception: + target: nemo.collections.multimodal.speech_llm.modules.perception_modules.AudioPerceptionModule + use_multi_layer_feat: false + multi_layer_feat: + layer_idx_list: [0,16] # layer indices to extract features from + aggregator: + mode: "cat" # ways to combine features from different layers, choices=['cat','sum','mean', 'max', 'min'], default to concat ('cat') + pooling: "avg" # ways to pool features if they have different temporal lengths and align_mode=min, choices=['mean', 'max', 'min'] + align_mode: "min" # if features have different temporal lengths, set `min` to pool to the shortest length or `max` to repeat to the longest. 
+ + modality_adapter: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 1024 + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 2 + d_model: 512 + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Reduction parameters: Can be used to add another subsampling layer at a given position. + # Having a 2x reduction will speedup the training and inference speech while keeping similar WER. + # Adding it at the end will give the best WER while adding it at the beginning will give the best speedup. + reduction: null # pooling, striding, or null + reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + # the following are read from the pretrained AM: + # output_dim: null + # encoder: null + # preprocessor: null + + codec_model_path: ??? 
+ asr_model_path: "stt_en_fastconformer_transducer_large" + + data: + end_string: "" + train_ds: + input_cfg: + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/es/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/es/timestamp/ + weight: 1.0 + tags: + lang: en + direct_s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/msmacro/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro/timestamp/ + weight: 0.2 + tags: + lang: en + direct_s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/alpaca/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/alpaca/timestamp/ + weight: 0.1 + tags: + lang: en + direct_s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/squadv2/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/squadv2/timestamp/ + weight: 0.05 + tags: + lang: en + direct_s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/msmacro_speech_instruct/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro_speech_instruct/timestamp/ + weight: 0.2 + tags: + lang: en + direct_s2s: True + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + pin_memory: True + max_seq_length: 512 + min_seq_length: 1 + drop_last: True + # Notably, the data weights are controlled by either bucketing_weights + # or concat_sampling_probabilities depending on the dataset type (tar and + # non-tar). + # See audio_text_qa_dataset.py for details. + concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + context_key: 'context' + ali_score_key: 'ali_score' + answer_key: 'answer' + add_eos: True + # add_eos: False + end_string: ${model.data.end_string} + add_sep: False + add_bos: False + separate_prompt_and_response_with_newline: False + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + # prompt_template: "{context}\n{answer}" # fstring to use for assistant prompt. 
Example: "Q: {input}\nA: {output}" + prompt_template: "[INST]\n<>\nPlease answer the following based on the previous speech feature.\n<>\n\n{context}[/INST] {answer}" + speech_pad_id: ${model.speech_pad_id} + speech_unk_id: ${model.speech_unk_id} + speech_bos_id: ${model.speech_bos_id} + speech_eos_id: ${model.speech_eos_id} + filter_by_source_target_text_ratio: False # If the length of the source and target texts differ too much, the instance will be discarded + source_target_text_ratio_limit: 4.0 + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + max_duration: 24 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + # tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "fully_randomized" + bucketing_batch_size: null + use_lhotse: True + text_field : "text" + batch_duration : 80 # 0 + quadratic_duration : 30 + num_buckets : 30 + buffer_size : 10000 + shuffle_buffer_size : 10000 + duration_bins: null + + validation_ds: + names: s2st + input_cfg: + - type: group + weight: 1.0 + tags: + lang: en + direct_s2s: True + input_cfg: + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/es_validation/ # ST + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/es/timestamp/ + weight: 1.0 + tags: + lang: en + direct_s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/msmacro/ # text SQA + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro/timestamp/ + weight: 1.0 + tags: + lang: en + direct_s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/alpaca/ # text SQA + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/alpaca/timestamp/ + weight: 1.0 + tags: + lang: en + direct_s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/squadv2/ # speech SQA + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/squadv2/timestamp/ + weight: 0.1 + tags: + lang: en + direct_s2s: True + - type: lhotse_shar + # shar_path: /workspace/data/s2s_shars/msmacro_speech_instruct/ # speech SQA + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro_speech_instruct/timestamp/ + weight: 0.9 + tags: + lang: en + direct_s2s: True + + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + pin_memory: True + max_seq_length: 512 + min_seq_length: 1 + drop_last: False + context_key: ${model.data.train_ds.context_key} + ali_score_key: ${model.data.train_ds.ali_score_key} + answer_key: ${model.data.train_ds.answer_key} + add_eos: ${model.data.train_ds.add_eos} + end_string: ${model.data.end_string} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: True + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. 
Example: "Q: {input}\nA: {output}" + tokens_to_generate: 256 + speech_pad_id: ${model.speech_pad_id} + speech_unk_id: ${model.speech_unk_id} + speech_bos_id: ${model.speech_bos_id} + speech_eos_id: ${model.speech_eos_id} + filter_by_source_target_text_ratio: False # If the length of the source and target texts differ too much, the instance will be discarded + source_target_text_ratio_limit: 4.0 + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + + log_every_n_steps: 10 + metrics: + - name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + - name: "bleu" + average: null + num_classes: null + - name: "asr-bleu" + average: null + num_classes: null + - name: "asr-wer" + average: null + num_classes: null + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 50 + min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 + constant_steps: 0 # Constant steps should also be 0 when min_lr=0 + monitor: val_loss + reduce_on_plateau: false diff --git a/examples/multimodal/speech_llm/conf/s2s/pt_salm_1b.yaml b/examples/multimodal/speech_llm/conf/s2s/pt_salm_1b.yaml index cadf028e9..ccea6633f 100644 --- a/examples/multimodal/speech_llm/conf/s2s/pt_salm_1b.yaml +++ b/examples/multimodal/speech_llm/conf/s2s/pt_salm_1b.yaml @@ -14,6 +14,8 @@ name: megatron_audio_gpt_s2s_lhotse +log_level: INFO + trainer: devices: 1 accelerator: gpu @@ -236,11 +238,12 @@ model: train_ds: input_cfg: - type: lhotse_shar - shar_path: /workspace/data/s2s_shars/es/ + # shar_path: /workspace/data/s2s_shars/es/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/es/src_timestamp weight: 1.0 tags: lang: en - s2s: True + s2s_align: True global_batch_size: ${model.global_batch_size} micro_batch_size: ${model.micro_batch_size} shuffle: True @@ -300,14 +303,15 @@ model: weight: 1.0 tags: lang: en - s2s: True + s2s_align: True input_cfg: - type: lhotse_shar - shar_path: /workspace/data/s2s_shars/es_validation/ + # shar_path: /workspace/data/s2s_shars/es_validation/ + shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/es/src_timestamp weight: 1.0 tags: lang: en - s2s: True + s2s_align: True global_batch_size: ${model.global_batch_size} micro_batch_size: ${model.micro_batch_size} shuffle: False @@ -338,7 +342,6 @@ model: source_target_text_ratio_limit: 4.0 # ASR configs sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} - log_every_n_steps: 10 metrics: - name: "loss" # Name of the evaluation metric to use. 
Options: ['exact_string_match', 'loss'] diff --git a/examples/multimodal/speech_llm/modular_audio_gpt_train.py b/examples/multimodal/speech_llm/modular_audio_gpt_train.py index ad8aacef2..07997d789 100644 --- a/examples/multimodal/speech_llm/modular_audio_gpt_train.py +++ b/examples/multimodal/speech_llm/modular_audio_gpt_train.py @@ -47,6 +47,10 @@ @hydra_runner(config_path="conf", config_name="modular_audio_gpt_config_peft") def main(cfg) -> None: + # Set up logging with the specified log level + logging_level = getattr(logging, cfg.log_level.upper(), logging.INFO) + logging.setLevel(logging_level) + logging.info("\n\n************** Experiment configuration ***********") logging.info(f'\n{OmegaConf.to_yaml(cfg)}') # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams diff --git a/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py index 65726bae8..eba36b0ba 100644 --- a/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py +++ b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py @@ -1,5 +1,5 @@ -import logging import random +import re import torch.utils.data from lhotse import CutSet @@ -12,6 +12,7 @@ build_loss_mask, ceil_to_nearest, ) +from nemo.utils import logging def collate_vectors(items, max_length: int, padding_value): @@ -108,25 +109,79 @@ def __getitem__(self, cuts) -> dict[str, torch.Tensor | list[str] | dict]: instructions, instruction_lengths = [], [] source_texts, source_text_lengths = [], [] # Not used in the current implementation target_texts, target_text_lengths = [], [] + start_time_tokens, word_lengths = [], [] remove_ids = [] for id, cut in enumerate(cuts): - metadata.append({'audio_filepath': cut.id + '.wav'}) + if id == 0: + logging.debug(f'audio_filepath: {cut.id}.wav') + logging.debug(f'cut: {cut}') + metadata.append({'audio_filepath': cut.id + '.wav'}) # TODO: the following use of _process_example is not ideal. 
Should update instruction = self.text_processor._process_example(context=cut.supervisions[0].text, output="") instruction, instruction_length = torch.as_tensor(instruction["input_ids"][:-1]), torch.as_tensor( len(instruction["input_ids"]) - 1 ) + if id == 0: + logging.debug(f'instruction: {cut.supervisions[0].text}') source_text = self.text_processor._process_example(context=cut.supervisions[1].text, output="") source_text, source_text_length = torch.as_tensor(source_text["input_ids"]), torch.as_tensor( len(source_text["input_ids"]) ) - - target_text = self.text_processor._process_example(context="", output=cut.supervisions[2].text) - # -1 to remove the eos token added by the text processor - target_text, target_text_length = torch.as_tensor(target_text["answer_ids"][:-1]), torch.as_tensor( - len(target_text["answer_ids"]) - 1 - ) + if id == 0: + logging.debug(f'source_text: {cut.supervisions[1].text}') + + def extract_text_and_time_tokens(input_sequence): + # Regular expression to match time tokens (e.g., <|x|> where x is an integer) + time_token_pattern = r"<\|(\d+)\|>" + # Find all time tokens + time_tokens = re.findall(time_token_pattern, input_sequence) + # Only keep the first token of every pair (i.e., start time tokens) + start_time_token = [int(time_tokens[i]) for i in range(0, len(time_tokens), 2)] + # Remove all time tokens to isolate words + words = re.sub(time_token_pattern, '', input_sequence).split() + # Process each word, tokenize it, and calculate token lengths + tokenized_words = [] + word_length = [] + for idx, word in enumerate(words): + # Tokenize the word using the provided text processor + if id == 0: + logging.debug(f'word: {word}') + tokenized_word = self.text_processor._process_example(context="", output=word) + # Remove the EOS token (assuming the EOS token is at the end of "answer_ids") + token_ids = tokenized_word["answer_ids"][:-1] # Remove EOS token + if idx != 0: # If not the first word, remove the first token + token_ids = token_ids[1:] + if id == 0: + logging.debug(f'token_ids: {token_ids}') + token_length = len(token_ids) # Calculate the length + tokenized_words.extend(token_ids) + word_length.append(token_length) + return ( + torch.as_tensor(tokenized_words), + torch.as_tensor(start_time_token), + torch.as_tensor(word_length), + ) + + # import pdb; pdb.set_trace() + use_timestamp = getattr(cut, "s2s_align", False) + if not use_timestamp: + pattern = r"<\|\d+\|>" + output_text = re.sub(pattern, "", cut.supervisions[2].text) + output_text = re.sub(r'\s+', ' ', output_text).strip() + target_text = self.text_processor._process_example(context="", output=output_text) + # -1 to remove the eos token added by the text processor + target_text, target_text_length = torch.as_tensor(target_text["answer_ids"][:-1]), torch.as_tensor( + len(target_text["answer_ids"]) - 1 + ) + if id == 0: + logging.debug(f'target_text: {output_text}') + else: + target_text, start_time_token, word_length = extract_text_and_time_tokens(cut.supervisions[2].text) + target_text_length = len(target_text) + # import pdb; pdb.set_trace() + if id == 0: + logging.debug(f'target_text: {cut.supervisions[2].text}') if self.filter_by_source_target_text_ratio: if ( @@ -142,6 +197,9 @@ def __getitem__(self, cuts) -> dict[str, torch.Tensor | list[str] | dict]: source_text_lengths.append(source_text_length) target_texts.append(target_text) target_text_lengths.append(target_text_length) + if use_timestamp: + word_lengths.append(word_length) + start_time_tokens.append(start_time_token) cuts = [c for i, c 
in enumerate(cuts) if i not in remove_ids] @@ -215,6 +273,8 @@ def collate_and_pad(inputs): # Loop through cuts and build target_codec for i, cut in enumerate(cuts): feat_i = cut.target_codes.load() + # logging.debug(f'frame_shift: {cut.target_codes.frame_shift}') + # logging.debug(f'feat_i.shape: {feat_i.shape}') target_codec[i, : feat_i.shape[0], 0] = text_unk_id feat_i = feat_i[: features_lens[i] * self.decoder_reduction_factor, : self.n_speech_codebooks] feat_i = feat_i.reshape((-1, self.n_speech_codebooks * self.decoder_reduction_factor)) @@ -222,6 +282,8 @@ def collate_and_pad(inputs): target_codec[i, feat_i.shape[0], :] = eos_tensor target_codec = target_codec.to(torch.int) + logging.debug(f'target_codec.shape: {target_codec.shape} ') + logging.debug(f'features_lens.shape: {features_lens.shape} ') source_texts, source_text_lengths = collate_and_pad(source_texts) @@ -243,6 +305,9 @@ def _convert_text_to_3d_tensor(texts, include_eos=True, tokens_to_generate=0): texts_expanded = texts_expanded[:, :-1] return texts, text_lengths, texts_expanded + # import pdb; pdb.set_trace() + + unpadded_target_texts = target_texts target_texts, target_text_lengths, target_texts_expanded = _convert_text_to_3d_tensor(target_texts) instructions, instruction_lengths, instructions_expanded_no_eos = _convert_text_to_3d_tensor( # tokens_to_generate is used in inference @@ -253,12 +318,125 @@ def _convert_text_to_3d_tensor(texts, include_eos=True, tokens_to_generate=0): # answers = torch.concat([speaker_context, bos_tensor, target_codec], 1) - if getattr(cut, "s2s", False): + logging.debug(f'target_texts_expanded.shape: {target_texts_expanded.shape} ') + logging.debug(f'target_text_lengths.shape: {target_text_lengths.shape} ') + logging.debug(f'cut: {cut} ') + + def discretize_time(start_token, speech_resolution=0.08, timestamp_resolution=0.08): + """Convert the start token into a time index based on the resolution.""" + return int(start_token * timestamp_resolution / speech_resolution) + + def _expand_text_with_timestamps_and_word_lengths( + word_tokens, word_lengths, start_time_tokens, features_lens, frame_rate=0.08, pad_id=None + ): + """ + Expand word tokens according to start time tokens and word lengths for a batch of sequences. + + Args: + - word_tokens: List of lists of token sequences (each inner list is a word's token IDs), shape [batch][time]. + - word_lengths: List of lists of word lengths, shape [batch][time]. + - start_time_tokens: List of lists of start times, shape [batch][time]. + - max_length: Maximum length in the time dimension (number of frames). + - frame_rate: Frame rate resolution. + - pad_id: Padding ID to use for empty positions in the tensor. + + Returns: + - 2D tensor [batch, max_length] where each row is the expanded token sequence for that batch. 
+ """ + if pad_id is None: + raise ValueError("pad_id must be provided.") + + batch_size = len(word_tokens) + max_length = max(features_lens).item() + + # Create the empty 2D tensor [batch, max_length] with pad_id as the default value + texts_expanded = torch.full((batch_size, max_length), fill_value=pad_id, dtype=torch.long) + + # Iterate over each batch + for batch_idx in range(batch_size): + batch_max_length = features_lens[batch_idx] + word_start_idx = 0 # Start index to keep track of the position within the concatenated word tokens + + # Iterate over the words in the current batch + for word_idx, word_length in enumerate(word_lengths[batch_idx]): + start_token = start_time_tokens[batch_idx][word_idx] + + # Convert the start time token into a time index based on frame rate + start_time_index = discretize_time(start_token, frame_rate) + + # Reduction of start time index due to stacking of frames + start_time_index = int(start_time_index / self.decoder_reduction_factor) + if batch_idx == 0: + logging.debug(f'start_time_index[0]: {start_time_index}') + + # Calculate the end time index based on word length + end_time_index = start_time_index + word_length + end_time_index = min(end_time_index, max_length) # Ensure it doesn't exceed max length + + # Get the word tokens for the current word + word_token_ids = word_tokens[batch_idx][word_start_idx : word_start_idx + word_length] + + # Populate the tokens in the expanded tensor at the correct positions + for t_idx in range(start_time_index, end_time_index): + if t_idx - start_time_index < len(word_token_ids): # Ensure tokens are within bounds + token_id = word_token_ids[t_idx - start_time_index] # Get token for this time step + texts_expanded[batch_idx][t_idx] = token_id # Directly assign the token ID + + # Move to the next word in the concatenated word tokens + word_start_idx += word_length + + # Overwrite padding tokens + texts_expanded[batch_idx][batch_max_length:] = text_pad_id + + return texts_expanded + + # import pdb; pdb.set_trace() + + # TODO(huk): Consider smaller reduction factor + if getattr(cut, "s2s_align", False): + max_feat_len = max(features_lens).item() + 1 + # [batch, max_feat_len] + target_text_expanded = _expand_text_with_timestamps_and_word_lengths( + unpadded_target_texts, + word_lengths, + start_time_tokens, + features_lens + 1, + cut.target_codes.frame_shift, + pad_id=text_unk_id, + ) + # import pdb; pdb.set_trace() + logging.debug(f'start_time_token: {start_time_tokens[0]}') + logging.debug(f'word_length: {word_lengths[0]}') + logging.debug(f'target_tokens: {unpadded_target_texts[0]}') + logging.debug(f'target_text_expanded: {target_text_expanded[0,:]}') + # [batch, max_feat_len, 1+V], where V = #codebooks * reduction_factor + target_codec[:, :, 0] = target_text_expanded + token_list = target_codec + + logging.debug(f'token_list[0].shape: {token_list[0].shape}') + if not self.t5_style: + token_list = [ + torch.concat([it[:itl], tt], 0) + for tt, it, itl in zip(token_list, instructions_expanded_no_eos, instruction_lengths) + ] + tokens, _ = collate_and_pad(token_list) + speech_loss_mask = tokens[:, :, 1:] != self.speech_pad_id + # Make the text loss mask the same as speech since they are aligned + loss_mask = torch.cat([speech_loss_mask[..., :1], speech_loss_mask], dim=-1) + if not self.t5_style: + for itl in instruction_lengths: + loss_mask[:, :itl, :] = False + # loss_mask = torch.cat([text_loss_mask, speech_loss_mask], 2) + # full_lengths = target_text_lengths + 1 + features_lens + 1 + instruction_length + 
full_lengths = features_lens + 1 + instruction_length + elif getattr(cut, "s2s", False): # Add 1 for eos token token_list = [ torch.concat([tt[: ttl + 1], tc[: tcl + 1]], 0) for tt, ttl, tc, tcl in zip(target_texts_expanded, target_text_lengths, target_codec, features_lens) ] + # import pdb; pdb.set_trace() + logging.debug(f'token_list[0].shape: {token_list[0].shape}') if not self.t5_style: token_list = [ torch.concat([it[:itl], tt], 0) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index 2dd838e54..b15a851f5 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -14,6 +14,7 @@ import itertools import json +import logging import os from functools import partial from typing import List, Optional, Union @@ -289,6 +290,8 @@ def inject_perception_input( attention_mask = self._create_attention_mask(encoder_input) position_ids = build_position_ids(encoder_input[:, :, 0]) + # import pdb; pdb.set_trace() + # Add position embeddings if ( getattr(lm_embedding, "position_embeddings", None) is not None @@ -328,6 +331,7 @@ def _get_text_embeddings(self, text_tokens, position_ids): def prepare_llm_input(self, audio_batch): """Prepare input for the LLM.""" input_signal = audio_batch['audio_signal'] + logging.debug(f'input_signal.shape: {input_signal.shape}') input_signal_length = audio_batch['audio_signal_length'] input_ids, input_length, labels, loss_mask = ( @@ -348,6 +352,9 @@ def prepare_llm_input(self, audio_batch): processed_signal_length=None, ) + logging.debug(f'encoded.shape: {encoded.shape}') + logging.debug(f'encoded_len.shape: {encoded_len.shape}') + logging.debug(f'num_audios: {num_audios}') if num_audios is not None: # split the encoded and encoded_len by num_audios, used when there're multiple audio files per sample encoded = encoded.split(num_audios.tolist()) @@ -604,6 +611,8 @@ def fwd_output_only_func(dataloader_iter, model): **extra_arg, ) + # import pdb; pdb.set_trace() + # Advance inference sequence offset. if self.inference_params: # if last stage, then (final) output is [b, s, h], otherwise it's [s, b, h] diff --git a/nemo/collections/multimodal/speech_llm/models/modular_s2s_models.py b/nemo/collections/multimodal/speech_llm/models/modular_s2s_models.py index e97a75b53..7bf6b33ca 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_s2s_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_s2s_models.py @@ -62,6 +62,7 @@ def __init__( self.proj_head_dims = proj_head_dims def forward(self, input_): + if input_.ndim == 3: assert input_.shape[2] == len(self.proj_head_dims) input_ = input_.clone() @@ -311,6 +312,8 @@ def inference_step(self, dataloader_iter, mode): """ Used for validation and test steps, added postprocessing after calling self.predict_step(). 
""" + # import pdb; pdb.set_trace() + # Evaluation of multimodal data follows the same pattern as training except predict_step batch, batch_idx, dataloader_idx = next(dataloader_iter) data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds @@ -403,15 +406,20 @@ def parse_decoder_outputs( decoder_output = input_decoder_output[-1:].tile([max_len, 1]) decoder_output[: max_len - context_length] = input_decoder_output[context_length:] + # Do not split because text and speech are now aligned # Split text and speech part based on the position of the first separator token - sep_pos = (decoder_output[:, 0] == text_separator).long() - if torch.any(sep_pos): - first_sep_pos = torch.argmax(sep_pos) - text_tokens = decoder_output[:first_sep_pos, 0] - speech_tokens = decoder_output[first_sep_pos + 1 :, 1:] - else: - text_tokens = decoder_output[:, 0] - speech_tokens = decoder_output[:, 1:] + # sep_pos = (decoder_output[:, 0] == text_separator).long() + # if torch.any(sep_pos): + # first_sep_pos = torch.argmax(sep_pos) + # text_tokens = decoder_output[:first_sep_pos, 0] + # speech_tokens = decoder_output[first_sep_pos + 1 :, 1:] + # else: + # text_tokens = decoder_output[:, 0] + # speech_tokens = decoder_output[:, 1:] + text_tokens = decoder_output[:, 0] + speech_tokens = decoder_output[:, 1:] + + # import pdb; pdb.set_trace() # Get speech token ids n_speech_codebooks = self.model.n_proj_heads - 1 @@ -816,6 +824,7 @@ def write_predictions_to_file(self, outputs, output_file_path_prefix, output_dir def de_concat_multiproj_logits(self, logits): logits_list = [] prev = 0 + # import pdb; pdb.set_trace() for i in self.model.proj_head_dims: logits_list.append(logits[:, prev : prev + i]) prev += i diff --git a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py index fa62a01a7..a1395f90b 100644 --- a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py +++ b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py @@ -140,6 +140,7 @@ def end_of_generation_condition( returns: a boolean tensor indicating whether the generation should stop """ + # import pdb; pdb.set_trace() if len(end_strings) == 1 and end_strings[0] == END_OF_SEQ: return prev == eod_id else: diff --git a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py index 82de9b723..a6e16ebaf 100644 --- a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py +++ b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py @@ -787,6 +787,8 @@ def get_prev(logits, started, temperature, extra): prev = torch.multinomial(probs, num_samples=1).view(-1) return prev + # import pdb; pdb.set_trace() + prev = [get_prev(logits_i, started, temperature, extra) for logits_i in logits] prev = torch.stack(prev, dim=1) started_expand = started.unsqueeze(1).expand(-1, prev.size(1)) @@ -833,6 +835,8 @@ def get_prev(logits, started, temperature, extra): model.cfg.speech_eos_id, ) + # import pdb; pdb.set_trace() + done_token = done_token.byte() & started.byte() just_finished = (done_token & ~is_done).bool() diff --git a/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py b/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py index 7fc160a8e..46613f574 100644 
--- a/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py +++ b/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py @@ -310,6 +310,8 @@ def _process_example(self, context: str, output: str): else: text = context + ' ' + output + # logging.debug(f'text: {text}') + if self.virtual_tokens: # (@adithyare) we are going to insert "pad/eos" tokens in the beginning of the text and context # these pad/eos tokens are placeholders for virtual tokens @@ -321,6 +323,8 @@ def _process_example(self, context: str, output: str): if self.end_string: answer_ids += self.tokenizer.text_to_ids(self.end_string) + # logging.debug(f'answer_text: {answer_text}') + if self.audio_locator is None: # signle audio case context_ids = self.tokenizer.text_to_ids(context)
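For reference, the s2s_vocab_size in the new pt_salm_1a_s2s.yaml and pt_salm_1a_s2s_direct.yaml configs follows directly from the formula given in their comments, using the defaults in this patch (one 32k text head plus eight 1024-entry codec heads, stacked with decoder_reduction_factor 5). A quick check in Python:

    # Worked check of the config comment:
    # s2s_vocab_size = proj_head_dims[0] + decoder_reduction_factor * sum(proj_head_dims[1:])
    proj_head_dims = [32000] + [1024] * 8
    decoder_reduction_factor = 5
    s2s_vocab_size = proj_head_dims[0] + decoder_reduction_factor * sum(proj_head_dims[1:])
    assert s2s_vocab_size == 32000 + 5 * 8 * 1024 == 72960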
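For cuts tagged s2s_align (the tag the updated pt_salm_1a.yaml and pt_salm_1b.yaml now set), the collator in lhotse_dataset.py expects the target text in cut.supervisions[2] to carry pairs of time tokens of the form <|x|> around each word and keeps only the start token of each pair. A minimal standalone sketch of that parse, mirroring the regex in extract_text_and_time_tokens; the annotated string below is a made-up example, and the real code additionally tokenizes each word with the text processor and records per-word token lengths:

    import re

    def split_time_tokens(annotated: str):
        # Keep the first time token of every <|start|> ... <|end|> pair and
        # strip all time tokens from the text.
        time_token_pattern = r"<\|(\d+)\|>"
        time_tokens = re.findall(time_token_pattern, annotated)
        start_times = [int(time_tokens[i]) for i in range(0, len(time_tokens), 2)]
        words = re.sub(time_token_pattern, "", annotated).split()
        return words, start_times

    words, starts = split_time_tokens("<|12|>hello<|15|> <|16|>world<|20|>")
    # words  -> ['hello', 'world']
    # starts -> [12, 16]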
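The parsed start times are then used by _expand_text_with_timestamps_and_word_lengths to write each word's text tokens onto channel 0 of target_codec at the codec frame where the word starts, so the text stream stays time-aligned with the speech codebooks and the text loss mask can simply mirror the speech mask. Below is a simplified sketch of the per-word placement only, assuming the timestamp resolution equals the codec frame shift (the patch's 0.08 s defaults) so a start-time token maps to a frame index after division by decoder_reduction_factor; the channel length, pad value, and token IDs are illustrative:

    import torch

    def place_word(text_channel, word_token_ids, start_time_token, reduction_factor=5):
        # Start frame = start-time token / reduction factor (frames are stacked
        # by decoder_reduction_factor before alignment).
        start = int(start_time_token / reduction_factor)
        end = min(start + len(word_token_ids), text_channel.shape[0])
        text_channel[start:end] = torch.as_tensor(word_token_ids[: end - start])
        return text_channel

    channel = torch.full((20,), 1002)                          # pre-filled with a placeholder unk id
    place_word(channel, [101, 102, 103], start_time_token=40)  # lands on frames 8..10
    # channel[8:11] -> tensor([101, 102, 103]); all other frames keep the unk id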