Set Minitron width pruning batch size 1 (#11603)
Signed-off-by: Keval Morabia <[email protected]>
kevalmorabia97 authored Dec 16, 2024
1 parent 31fc38e commit f640a02
Showing 4 changed files with 14 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cicd-main.yml
@@ -570,7 +570,7 @@ jobs:
prune.ffn_hidden_size=192 \
prune.num_attention_heads=2 \
prune.num_query_groups=2 \
- prune.hidden_size=null \
+ prune.hidden_size=128 \
export.save_path=examples/nlp/language_modeling/ci_prune_width.nemo
AFTER_SCRIPT: |
rm -rf examples/nlp/language_modeling/ci_prune_width.nemo
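The CI job above passes its pruning constraints as Hydra-style dot-notation overrides on top of the defaults in megatron_gpt_prune.yaml (next file). As a rough illustration of how such overrides land in the config, here is a minimal sketch using plain OmegaConf, which the hydra_runner decorator builds on; the merge below is a simplification, not the exact code path inside NeMo:

```python
from omegaconf import OmegaConf

# Base "prune" section as in megatron_gpt_prune.yaml (null/None means "do not prune this axis").
base = OmegaConf.create(
    {
        "prune": {
            "ffn_hidden_size": None,
            "num_attention_heads": None,
            "num_query_groups": None,
            "hidden_size": None,
            "num_layers": None,
        }
    }
)

# The CI job's command-line overrides, expressed as a dot-list.
overrides = OmegaConf.from_dotlist(
    [
        "prune.ffn_hidden_size=192",
        "prune.num_attention_heads=2",
        "prune.num_query_groups=2",
        "prune.hidden_size=128",
    ]
)

# Overrides win over the base values, so all four width dimensions get pruned in the CI test.
print(OmegaConf.to_yaml(OmegaConf.merge(base, overrides)))
```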
10 changes: 5 additions & 5 deletions examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml
@@ -9,7 +9,7 @@ inference:
repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
compute_logprob: false # a flag used to compute logprob of all the input text, a very special case of running inference, default False
- batch_size: 64 # batch size for inference
+ batch_size: 1 # batch size for inference
max_context_length: 512 # max length of the context, input sequence will be truncated if it is longer than this

trainer:
@@ -24,7 +24,7 @@ model:
tensor_model_parallel_size: 1 # Pruning currently only supports tensor_model_parallel_size=1
pipeline_model_parallel_size: 1
sequence_parallel: false # Sequence parallelism is not supported with pipeline parallelism
- restore_from_path: llama3.1-8b-instruct.nemo # Nemo file path
+ restore_from_path: ??? # Nemo file path

## Activation Checkpoint
activations_checkpoint_granularity: null # 'selective' or 'full'
@@ -34,11 +34,11 @@ prune:
calib_dataset: wikitext # wikitext, cnn_dailymail, or a local dataset
num_calib_size: 1024 # number of samples used for calibration
# pruning constraints (null means no pruning)
- ffn_hidden_size: 9216 # ffn_hidden_size in the pruned model
+ ffn_hidden_size: null # ffn_hidden_size in the pruned model
num_attention_heads: null # num_attention_heads in the pruned model
num_query_groups: null # num_query_groups in the pruned model
- hidden_size: 3072 # hidden_size (embedding size) in the pruned model
+ hidden_size: null # hidden_size (embedding size) in the pruned model
num_layers: null # num_layers (depth) in the pruned model

export:
- save_path: llama3.1-8b-instruct-pruned.nemo # Path where the pruned model will be saved
+ save_path: ??? # Path where the pruned model will be saved
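The defaults for restore_from_path and save_path change from hard-coded Llama paths to ???, OmegaConf's marker for a mandatory value, so the script fails fast when the caller forgets to supply them. A minimal standalone sketch of that behaviour (the my_model.nemo path is a placeholder):

```python
from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create({"model": {"restore_from_path": "???"}})

try:
    # Reading a "???" value raises instead of silently using a stale default path.
    _ = cfg.model.restore_from_path
except MissingMandatoryValue:
    print("model.restore_from_path must be provided on the command line")

# Supplying the override (here via a dot-list) resolves the mandatory value.
cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(["model.restore_from_path=my_model.nemo"]))
print(cfg.model.restore_from_path)  # -> my_model.nemo
```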
11 changes: 6 additions & 5 deletions examples/nlp/language_modeling/megatron_gpt_prune.py
@@ -13,7 +13,6 @@
# limitations under the License.

import modelopt.torch.prune as mtp
- import torch
import torch.multiprocessing as mp
from datasets import load_dataset
from lightning.pytorch.trainer.trainer import Trainer
@@ -36,7 +35,7 @@
Example usage:
```
python examples/nlp/language_modeling/megatron_gpt_prune.py \
- model.restore_from_path=llama3.1-8b-instruct.nemo \
+ model.restore_from_path=llama3.1-8b.nemo \
model.tensor_model_parallel_size=1 \
model.pipeline_model_parallel_size=8 \
trainer.num_nodes=1 \
@@ -46,13 +45,14 @@
prune.num_attention_heads=null \
prune.num_query_groups=null \
prune.hidden_size=3072 \
- export.save_path=llama3.1-8b-instruct-pruned.nemo
+ export.save_path=llama3.1-8b-pruned.nemo
```
- where tensor_model_parallel_size must be 1 because of the current prune API limitation
+ where model.tensor_model_parallel_size and inference.batch_size must be 1 because of the current prune API limitation
"""


def get_calib_data_iter(data="wikitext", batch_size=64, calib_size=512, max_sequence_length=512):
def get_calib_data_iter(data="wikitext", batch_size=1, calib_size=1024, max_sequence_length=512):
"""Get a data iterator for calibration."""
if data == "wikitext":
dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
text_column = "text"
@@ -73,6 +73,7 @@ def get_calib_data_iter(data="wikitext", batch_size=64, calib_size=512, max_sequence_length=512):

@hydra_runner(config_path="conf", config_name="megatron_gpt_prune")
def main(cfg) -> None:
"""Prune a model using modelopt."""
# Overwrite model config with the one from the model checkpoint and apply pruning modifications
model_cfg = load_config(cfg.model.restore_from_path)
model_cfg.update(cfg.model)
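The hunks above show only the wikitext branch of get_calib_data_iter and the new defaults (batch_size=1, calib_size=1024). For context, here is a self-contained sketch of what such a calibration iterator can look like; the cnn_dailymail branch and the batching loop are reconstructed from the config comments above and are assumptions, not lines taken from this commit:

```python
from datasets import load_dataset


def get_calib_data_iter(data="wikitext", batch_size=1, calib_size=1024, max_sequence_length=512):
    """Yield small batches of raw text used to calibrate importance scores before pruning."""
    if data == "wikitext":
        dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
        text_column = "text"
    elif data == "cnn_dailymail":
        dataset = load_dataset("cnn_dailymail", "3.0.0", split="train")
        text_column = "article"
    else:
        raise NotImplementedError(f"Unsupported calibration dataset: {data}")

    for i in range(calib_size // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size][text_column]
        # Truncate each sample so calibration forward passes stay within the max context length.
        yield [text[:max_sequence_length] for text in batch]


# Example: pull one batch of calibration text (requires the `datasets` package and network access).
print(next(get_calib_data_iter())[0][:80])
```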
@@ -28,9 +28,7 @@
"\n",
"We use the above parameters to get a competitive model for this demonstration. You can use other strategies or parameters from the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) or the [tech report](https://arxiv.org/pdf/2408.11796) for your experiments. \n",
"\n",
"> `NOTE:` In the block of code below, pass the paths to your fine-tuned teacher .nemo model.\n",
"\n",
"> `TIP:` You can increase the ``batch_size`` (upto 1024) to speed up the width-pruning script execution."
"> `NOTE:` In the block of code below, pass the paths to your fine-tuned teacher .nemo model."
]
},
{
@@ -48,7 +46,7 @@
" model.tensor_model_parallel_size=1 \\\n",
" model.pipeline_model_parallel_size=8 \\\n",
" +model.dist_ckpt_load_strictness=log_all \\\n",
" inference.batch_size=64 \\\n",
" inference.batch_size=1 \\\n",
" trainer.num_nodes=1 \\\n",
" trainer.precision=bf16 \\\n",
" trainer.devices=8 \\\n",