diff --git a/.gitignore b/.gitignore
index 250657f5..968fb052 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,4 +27,5 @@ torch_pruning_bak
 benchmarks/run
 output
 acc.png
-pretrained
\ No newline at end of file
+pretrained
+data
\ No newline at end of file
diff --git a/README.md b/README.md
index 308e6a41..579d5f3c 100644
--- a/README.md
+++ b/README.md
@@ -244,7 +244,25 @@ With DepGraph, it is easy to design some "group-level" criteria to estimate the
+#### Modify module attributes or forward function
+In some implementations, the model's forward function relies on static attributes. For example, in [``convformer_s18``](https://github.com/huggingface/pytorch-image-models/blob/054c763fcaa7d241564439ae05fbe919ed85e614/timm/models/metaformer.py#L107) from timm, we have:
+
+```python
+class Scale(nn.Module):
+    """
+    Scale vector by element multiplications.
+    """
+
+    def __init__(self, dim, init_value=1.0, trainable=True, use_nchw=True):
+        super().__init__()
+        self.shape = (dim, 1, 1) if use_nchw else (dim,) # static shape, which should be updated after pruning
+        self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)
+
+    def forward(self, x):
+        return x * self.scale.view(self.shape) # => x * self.scale.view(-1, 1, 1), this works for pruning
+```
+Here the ``forward`` function relies on ``self.shape``, but the true shape changes after pruning, so it must be updated manually (or the view can be rewritten as ``self.scale.view(-1, 1, 1)``, as noted in the comment above).
 ### 3. Save & Load
diff --git a/benchmarks/benchmark_importance_criteria.py b/benchmarks/benchmark_importance_criteria.py
index 3cf30161..c2ebc6e8 100644
--- a/benchmarks/benchmark_importance_criteria.py
+++ b/benchmarks/benchmark_importance_criteria.py
@@ -8,7 +8,7 @@ import matplotlib.pyplot as plt
 N_batchs = 10
-imagenet_root = '~/Datasets/shared/imagenet/'
+imagenet_root = 'data/imagenet'
 print('Parsing dataset...')
 train_dst = ImageFolder(os.path.join(imagenet_root, 'train'), transform=T.Compose(
     [
diff --git a/examples/timm_models/prune_timm_models.py b/examples/timm_models/prune_timm_models.py
new file mode 100644
index 00000000..6e3c50ea
--- /dev/null
+++ b/examples/timm_models/prune_timm_models.py
@@ -0,0 +1,65 @@
+import os, sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Sequence
+import timm
+from timm.models.vision_transformer import Attention
+import torch_pruning as tp
+import argparse
+
+parser = argparse.ArgumentParser(description='Prune timm models')
+parser.add_argument('--model', default=None, type=str, help='model name')
+parser.add_argument('--ch_sparsity', default=0.5, type=float, help='channel sparsity')
+parser.add_argument('--global_pruning', default=False, action='store_true', help='global pruning')
+parser.add_argument('--pretrained', default=False, action='store_true', help='use pretrained weights')
+parser.add_argument('--list_models', default=False, action='store_true', help='list all models in timm')
+args = parser.parse_args()
+
+def main():
+    timm_models = timm.list_models()
+    if args.list_models:
+        print(timm_models)
+    if args.model is None:
+        return
+    assert args.model in timm_models, "Model %s is not in timm model list: %s"%(args.model, timm_models)
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    model = timm.create_model(args.model, pretrained=args.pretrained, no_jit=True).eval().to(device)
+
+    imp = tp.importance.GroupNormImportance()
+    print("Pruning %s..."%args.model)
+
+    input_size = model.default_cfg['input_size']
+    example_inputs = torch.randn(1, *input_size).to(device)
+    test_output = model(example_inputs)
+    ignored_layers = []
+    for m in model.modules():
+        if isinstance(m, nn.Linear) and m.out_features == model.num_classes:
+            ignored_layers.append(m)
+            print("Ignore classifier layer: ", m)
+
+    print("========Before pruning========")
+    print(model)
+    base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
+    pruner = tp.pruner.MagnitudePruner(
+        model,
+        example_inputs,
+        global_pruning=args.global_pruning, # If False, a uniform sparsity will be assigned to different layers.
+        importance=imp, # importance criterion for parameter selection
+        iterative_steps=1, # the number of iterations to achieve target sparsity
+        ch_sparsity=args.ch_sparsity, # target sparsity
+        ignored_layers=ignored_layers,
+    )
+    for g in pruner.step(interactive=True):
+        g.prune()
+
+    print("========After pruning========")
+    print(model)
+    test_output = model(example_inputs)
+    pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
+    print("MACs: %.4f G => %.4f G"%(base_macs/1e9, pruned_macs/1e9))
+    print("Params: %.4f M => %.4f M"%(base_params/1e6, pruned_params/1e6))
+
+if __name__=='__main__':
+    main()
\ No newline at end of file
diff --git a/examples/timm_models/readme.md b/examples/timm_models/readme.md
index 3b262412..41a9b3fb 100644
--- a/examples/timm_models/readme.md
+++ b/examples/timm_models/readme.md
@@ -1,29 +1,53 @@
 # Pruning Models from Timm
-## 0. Requirements
+
+## 0. List all models in Timm
 ```bash
-pip install -r requirements.txt
+python prune_timm_models.py --list_models
 ```
-Tested environment:
+
+Output:
 ```
-pytorch==1.12.1
-timm=0.9.2
+['bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_384', 'beit_large_patch16_224', 'beit_large_patch16_384', 'beit_large_patch16_512', 'beitv2_base_patch16_224', ...]
 ```
 ## 1. Pruning
-```python
-python timm_pruning.py
+Some models require additional modifications to enable pruning. For example, we need to re-implement the forward function of `vit` to relax its structural constraints, as sketched below. Refer to [examples/transformers/prune_timm_vit.py](../transformers/prune_timm_vit.py) for more details.
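+
+The core of that modification is to relax the output reshape in ``Attention.forward`` so that a pruned head dimension remains valid, and to patch the new forward onto timm's attention module. Below is a condensed sketch of the code in that script (the fused-attention branch is omitted and the function name is only illustrative):
+
+```python
+import timm
+from timm.models.vision_transformer import Attention
+
+def relaxed_forward(self, x):
+    B, N, C = x.shape
+    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+    q, k, v = qkv.unbind(0)
+    q, k = self.q_norm(q), self.k_norm(k)
+    attn = self.attn_drop(((q * self.scale) @ k.transpose(-2, -1)).softmax(dim=-1))
+    x = (attn @ v).transpose(1, 2).reshape(B, N, -1)  # original: reshape(B, N, C), which ties input and output channels
+    return self.proj_drop(self.proj(x))
+
+model = timm.create_model('vit_base_patch16_224', pretrained=False)
+for m in model.modules():
+    if isinstance(m, Attention):
+        m.forward = relaxed_forward.__get__(m, Attention)
+```
+
+To prune a model: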
+ +```bash +python prune_timm_models.py --model convnext_xxlarge --ch_sparsity 0.5 # --global_pruning ``` #### Outputs: -Prunable: 119 models, -``` - ['beit_base_patch16_224', 'beit_base_patch16_384', 'beit_large_patch16_224', 'beit_large_patch16_384', 'beit_large_patch16_512', 'beitv2_base_patch16_224', 'beitv2_large_patch16_224', 'botnet26t_256', 'botnet50ts_256', 'convmixer_768_32', 'convmixer_1024_20_ks9_p14', 'convmixer_1536_20', 'convnext_atto', 'convnext_atto_ols', 'convnext_base', 'convnext_femto', 'convnext_femto_ols', 'convnext_large', 'convnext_large_mlp', 'convnext_nano', 'convnext_nano_ols', 'convnext_pico', 'convnext_pico_ols', 'convnext_small', 'convnext_tiny', 'convnext_tiny_hnf', 'convnext_xlarge', 'convnext_xxlarge', 'convnextv2_atto', 'convnextv2_base', 'convnextv2_femto', 'convnextv2_huge', 'convnextv2_large', 'convnextv2_nano', 'convnextv2_pico', 'convnextv2_small', 'convnextv2_tiny', 'darknet17', 'darknet21', 'darknet53', 'darknetaa53', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'densenet264d', 'dla34', 'dla46_c', 'dla46x_c', 'dla60', 'dla60x', 'dla60x_c', 'dla102', 'dla102x', 'dla102x2', 'dla169', 'eca_botnext26ts_256', 'eca_resnet33ts', 'eca_resnext26ts', 'eca_vovnet39b', 'ecaresnet26t', 'ecaresnet50d', 'ecaresnet50d_pruned', 'ecaresnet50t', 'ecaresnet101d', 'ecaresnet101d_pruned', 'ecaresnet200d', 'ecaresnet269d', 'ecaresnetlight', 'ecaresnext26t_32x4d', 'ecaresnext50t_32x4d', 'efficientnet_b0', 'efficientnet_b0_g8_gn', 'efficientnet_b0_g16_evos', 'efficientnet_b0_gn', 'efficientnet_b1', 'efficientnet_b1_pruned', 'efficientnet_b2', 'efficientnet_b2_pruned', 'efficientnet_b2a', 'efficientnet_b3', 'efficientnet_b3_gn', 'efficientnet_b3_pruned', 'efficientnet_b3a', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8', 'efficientnet_el', 'efficientnet_el_pruned', 'efficientnet_em', 'efficientnet_es', 'efficientnet_es_pruned', 'efficientnet_l2', 'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4', 'efficientnetv2_l', 'efficientnetv2_m', 'efficientnetv2_rw_m', 'efficientnetv2_rw_s', 'efficientnetv2_rw_t', 'efficientnetv2_s', 'efficientnetv2_xl', 'ese_vovnet19b_dw', 'ese_vovnet19b_slim', 'ese_vovnet19b_slim_dw', 'ese_vovnet39b', 'ese_vovnet57b', 'ese_vovnet99b', 'fbnetc_100', 'fbnetv3_b', 'fbnetv3_d', 'fbnetv3_g', 'gc_efficientnetv2_rw_t', 'gcresnet33ts'] ``` +========Before pruning======== +... 
+ (norm_pre): Identity() + (head): NormMlpClassifierHead( + (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Identity()) + (norm): LayerNorm2d((3072,), eps=1e-05, elementwise_affine=True) + (flatten): Flatten(start_dim=1, end_dim=-1) + (pre_logits): Identity() + (drop): Dropout(p=0.0, inplace=False) + (fc): Linear(in_features=3072, out_features=1000, bias=True) + ) +) -Unprunable: 175 models, -``` - ['bat_resnext26ts', 'caformer_b36', 'caformer_m36', 'caformer_s18', 'caformer_s36', 'cait_m36_384', 'cait_m48_448', 'cait_s24_224', 'cait_s24_384', 'cait_s36_384', 'cait_xs24_384', 'cait_xxs24_224', 'cait_xxs24_384', 'cait_xxs36_224', 'cait_xxs36_384', 'coat_lite_medium', 'coat_lite_medium_384', 'coat_lite_mini', 'coat_lite_small', 'coat_lite_tiny', 'coat_mini', 'coat_small', 'coat_tiny', 'coatnet_0_224', 'coatnet_0_rw_224', 'coatnet_1_224', 'coatnet_1_rw_224', 'coatnet_2_224', 'coatnet_2_rw_224', 'coatnet_3_224', 'coatnet_3_rw_224', 'coatnet_4_224', 'coatnet_5_224', 'coatnet_bn_0_rw_224', 'coatnet_nano_cc_224', 'coatnet_nano_rw_224', 'coatnet_pico_rw_224', 'coatnet_rmlp_0_rw_224', 'coatnet_rmlp_1_rw2_224', 'coatnet_rmlp_1_rw_224', 'coatnet_rmlp_2_rw_224', 'coatnet_rmlp_2_rw_384', 'coatnet_rmlp_3_rw_224', 'coatnet_rmlp_nano_rw_224', 'coatnext_nano_rw_224', 'convformer_b36', 'convformer_m36', 'convformer_s18', 'convformer_s36', 'convit_base', 'convit_small', 'convit_tiny', 'crossvit_9_240', 'crossvit_9_dagger_240', 'crossvit_15_240', 'crossvit_15_dagger_240', 'crossvit_15_dagger_408', 'crossvit_18_240', 'crossvit_18_dagger_240', 'crossvit_18_dagger_408', 'crossvit_base_240', 'crossvit_small_240', 'crossvit_tiny_240', 'cs3darknet_focus_l', 'cs3darknet_focus_m', 'cs3darknet_focus_s', 'cs3darknet_focus_x', 'cs3darknet_l', 'cs3darknet_m', 'cs3darknet_s', 'cs3darknet_x', 'cs3edgenet_x', 'cs3se_edgenet_x', 'cs3sedarknet_l', 'cs3sedarknet_x', 'cs3sedarknet_xdw', 'cspdarknet53', 'cspresnet50', 'cspresnet50d', 'cspresnet50w', 'cspresnext50', 'davit_base', 'davit_giant', 'davit_huge', 'davit_large', 'davit_small', 'davit_tiny', 'deit3_base_patch16_224', 'deit3_base_patch16_384', 'deit3_huge_patch14_224', 'deit3_large_patch16_224', 'deit3_large_patch16_384', 'deit3_medium_patch16_224', 'deit3_small_patch16_224', 'deit3_small_patch16_384', 'deit_base_distilled_patch16_224', 'deit_base_distilled_patch16_384', 'deit_base_patch16_224', 'deit_base_patch16_384', 'deit_small_distilled_patch16_224', 'deit_small_patch16_224', 'deit_tiny_distilled_patch16_224', 'deit_tiny_patch16_224', 'densenetblur121d', 'dla60_res2net', 'dla60_res2next', 'dm_nfnet_f0', 'dm_nfnet_f1', 'dm_nfnet_f2', 'dm_nfnet_f3', 'dm_nfnet_f4', 'dm_nfnet_f5', 'dm_nfnet_f6', 'dpn48b', 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn107', 'dpn131', 'eca_halonext26ts', 'eca_nfnet_l0', 'eca_nfnet_l1', 'eca_nfnet_l2', 'eca_nfnet_l3', 'edgenext_base', 'edgenext_small', 'edgenext_small_rw', 'edgenext_x_small', 'edgenext_xx_small', 'efficientformer_l1', 'efficientformer_l3', 'efficientformer_l7', 'efficientformerv2_l', 'efficientformerv2_s0', 'efficientformerv2_s1', 'efficientformerv2_s2', 'efficientnet_b3_g8_gn', 'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e', 'ese_vovnet39b_evos', 'eva02_base_patch14_224', 'eva02_base_patch14_448', 'eva02_base_patch16_clip_224', 'eva02_enormous_patch14_clip_224', 'eva02_large_patch14_224', 'eva02_large_patch14_448', 'eva02_large_patch14_clip_224', 'eva02_large_patch14_clip_336', 'eva02_small_patch14_224', 'eva02_small_patch14_336', 'eva02_tiny_patch14_224', 'eva02_tiny_patch14_336', 
'eva_giant_patch14_224', 'eva_giant_patch14_336', 'eva_giant_patch14_560', 'eva_giant_patch14_clip_224', 'eva_large_patch14_196', 'eva_large_patch14_336', 'flexivit_base', 'flexivit_large', 'flexivit_small', 'focalnet_base_lrf', 'focalnet_base_srf', 'focalnet_huge_fl3', 'focalnet_huge_fl4', 'focalnet_large_fl3', 'focalnet_large_fl4', 'focalnet_small_lrf', 'focalnet_small_srf', 'focalnet_tiny_lrf', 'focalnet_tiny_srf', 'focalnet_xlarge_fl3', 'focalnet_xlarge_fl4'] + +========After pruning======== +... + (norm_pre): Identity() + (head): NormMlpClassifierHead( + (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Identity()) + (norm): LayerNorm2d((1536,), eps=1e-05, elementwise_affine=True) + (flatten): Flatten(start_dim=1, end_dim=-1) + (pre_logits): Identity() + (drop): Dropout(p=0.0, inplace=False) + (fc): Linear(in_features=1536, out_features=1000, bias=True) + ) +) +MACs: 197.9920 G => 49.7716 G +Params: 846.4710 M => 213.2587 M ``` diff --git a/examples/timm_models/timm_beit.py b/examples/timm_models/timm_beit.py deleted file mode 100644 index b6d063c6..00000000 --- a/examples/timm_models/timm_beit.py +++ /dev/null @@ -1,123 +0,0 @@ -import os, sys -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))) - -import torch -import torch.nn as nn -import timm -import torch_pruning as tp -from typing import Sequence - -from timm.models.vision_transformer import Attention -import torch.nn.functional as F - -def forward(self, x, shared_rel_pos_bias = None): - B, N, C = x.shape - - qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None - qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) - qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim - - if self.fused_attn: - rel_pos_bias = None - if self.relative_position_bias_table is not None: - rel_pos_bias = self._get_rel_pos_bias() - if shared_rel_pos_bias is not None: - rel_pos_bias = rel_pos_bias + shared_rel_pos_bias - elif shared_rel_pos_bias is not None: - rel_pos_bias = shared_rel_pos_bias - - x = F.scaled_dot_product_attention( - q, k, v, - attn_mask=rel_pos_bias, - dropout_p=self.attn_drop.p, - ) - else: - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - if self.relative_position_bias_table is not None: - attn = attn + self._get_rel_pos_bias() - if shared_rel_pos_bias is not None: - attn = attn + shared_rel_pos_bias - - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = attn @ v - x = x.transpose(1, 2).reshape(B, N, -1) - #x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - -# timm==0.9.2 -# torch==1.12.1 - -timm_models = timm.list_models() -example_inputs = torch.randn(1,3,224,224) -imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean") -prunable_list = [] -unprunable_list = [] -problem_with_input_shape = [] - -for i, model_name in enumerate(timm_models): - if not model_name=='beit_base_patch16_224': - continue - - print("Pruning %s..."%model_name) - device = 'cuda' if torch.cuda.is_available() else 'cpu' - #if 'rexnet' in model_name or 'sequencer' in model_name or 'botnet' in model_name: # pruning process stuck with that architectures - skip them. 
- # unprunable_list.append(model_name) - # continue - try: - model = timm.create_model(model_name, pretrained=False, no_jit=True).eval().to(device) - except: # out of memory error - model = timm.create_model(model_name, pretrained=False, no_jit=True).eval() - device = 'cpu' - ch_groups = {} - for m in model.modules(): - if isinstance(m, timm.models.vision_transformer.Attention): - m.forward = timm_attention_forward.__get__(m, Attention) # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module - ch_groups[m.qkv] = m.num_heads * 3 - - input_size = model.default_cfg['input_size'] - example_inputs = torch.randn(1, *input_size).to(device) - test_output = model(example_inputs) - - print(model) - prunable = True - #try: - if True: - base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs) - pruner = tp.pruner.MagnitudePruner( - model, - example_inputs, - global_pruning=False, # If False, a uniform sparsity will be assigned to different layers. - importance=imp, # importance criterion for parameter selection - iterative_steps=1, # the number of iterations to achieve target sparsity - ch_sparsity=0.5, - ignored_layers=[model.head], - channel_groups=ch_groups, - ) - for g in pruner.step(interactive=True): - #print(g) - g.prune() - - # Modify the attention head size and all head size aftering pruning - for m in model.modules(): - if isinstance(m, timm.models.vision_transformer.Attention): - m.head_dim = m.qkv.out_features // (3 * m.num_heads) - - print(model) - test_output = model(example_inputs) - pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs) - print("Base MACs: %d, Pruned MACs: %d"%(base_macs, pruned_macs)) - print("Base Params: %d, Pruned Params: %d"%(base_params, pruned_params)) - - if prunable: - prunable_list.append(model_name) - else: - unprunable_list.append(model_name) - - print("Prunable: %d models, \n %s\n"%(len(prunable_list), prunable_list)) - print("Unprunable: %d models, \n %s\n"%(len(unprunable_list), unprunable_list)) \ No newline at end of file diff --git a/examples/timm_models/timm_pruning.py b/examples/timm_models/timm_pruning.py deleted file mode 100644 index f785b392..00000000 --- a/examples/timm_models/timm_pruning.py +++ /dev/null @@ -1,64 +0,0 @@ -import os, sys -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))) - -import torch -import timm -import torch_pruning as tp - -# timm==0.9.2 -# torch==1.12.1 - -timm_models = timm.list_models() -print(timm_models) -example_inputs = torch.randn(1,3,224,224) -imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean") -prunable_list = [] -unprunable_list = [] -problem_with_input_shape = [] -for i, model_name in enumerate(timm_models): - print("Pruning %s..."%model_name) - if "botnet" in model_name or "coatnet" in model_name or "coatnext" in model_name: - unprunable_list.append(model_name) - continue - device = 'cuda' if torch.cuda.is_available() else 'cpu' - #if 'rexnet' in model_name or 'sequencer' in model_name or 'botnet' in model_name: # pruning process stuck with that architectures - skip them. 
- # unprunable_list.append(model_name) - # continue - try: - model = timm.create_model(model_name, pretrained=False, no_jit=True).eval().to(device) - except: # out of memory error - model = timm.create_model(model_name, pretrained=False, no_jit=True).eval() - device = 'cpu' - - input_size = model.default_cfg['input_size'] - example_inputs = torch.randn(1, *input_size).to(device) - test_output = model(example_inputs) - - print(model) - prunable = True - try: - base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs) - pruner = tp.pruner.MagnitudePruner( - model, - example_inputs, - global_pruning=False, # If False, a uniform sparsity will be assigned to different layers. - importance=imp, # importance criterion for parameter selection - iterative_steps=1, # the number of iterations to achieve target sparsity - ch_sparsity=0.5, - ignored_layers=[], - ) - pruner.step() - test_output = model(example_inputs) - pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs) - print("Base MACs: %d, Pruned MACs: %d"%(base_macs, pruned_macs)) - print("Base Params: %d, Pruned Params: %d"%(base_params, pruned_params)) - except Exception as e: - prunable = False - - if prunable: - prunable_list.append(model_name) - else: - unprunable_list.append(model_name) - - print("Prunable: %d models, \n %s\n"%(len(prunable_list), prunable_list)) - print("Unprunable: %d models, \n %s\n"%(len(unprunable_list), unprunable_list)) \ No newline at end of file diff --git a/examples/timm_models/timm_vit.py b/examples/timm_models/timm_vit.py deleted file mode 100644 index 3f0ff4ad..00000000 --- a/examples/timm_models/timm_vit.py +++ /dev/null @@ -1,107 +0,0 @@ -import os, sys -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))) - -import torch -import torch.nn as nn -import timm -import torch_pruning as tp -from typing import Sequence - -from timm.models.vision_transformer import Attention -import torch.nn.functional as F - -def timm_attention_forward(self, x): - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) - q, k = self.q_norm(q), self.k_norm(k) - - if self.fused_attn: - x = F.scaled_dot_product_attention( - q, k, v, - dropout_p=self.attn_drop.p, - ) - else: - q = q * self.scale - attn = q @ k.transpose(-2, -1) - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = attn @ v - - #x = x.transpose(1, 2).reshape(B, N, C) # this line forces the input and output channels to be identical. - x = x.transpose(0, 1).reshape(B, N, -1) - x = self.proj(x) - x = self.proj_drop(x) - return x - -# timm==0.9.2 -# torch==1.12.1 - -timm_models = timm.list_models() -example_inputs = torch.randn(1,3,224,224) -imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean") -prunable_list = [] -unprunable_list = [] -problem_with_input_shape = [] - -for i, model_name in enumerate(timm_models): - if not model_name=='vit_base_patch8_224': - continue - - print("Pruning %s..."%model_name) - device = 'cuda' if torch.cuda.is_available() else 'cpu' - #if 'rexnet' in model_name or 'sequencer' in model_name or 'botnet' in model_name: # pruning process stuck with that architectures - skip them. 
- # unprunable_list.append(model_name) - # continue - try: - model = timm.create_model(model_name, pretrained=False, no_jit=True).eval().to(device) - except: # out of memory error - model = timm.create_model(model_name, pretrained=False, no_jit=True).eval() - device = 'cpu' - ch_groups = {} - for m in model.modules(): - if isinstance(m, timm.models.vision_transformer.Attention): - m.forward = timm_attention_forward.__get__(m, Attention) # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module - ch_groups[m.qkv] = m.num_heads * 3 - - input_size = model.default_cfg['input_size'] - example_inputs = torch.randn(1, *input_size).to(device) - test_output = model(example_inputs) - - print(model) - prunable = True - #try: - if True: - base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs) - pruner = tp.pruner.MagnitudePruner( - model, - example_inputs, - global_pruning=False, # If False, a uniform sparsity will be assigned to different layers. - importance=imp, # importance criterion for parameter selection - iterative_steps=1, # the number of iterations to achieve target sparsity - ch_sparsity=0.5, - ignored_layers=[model.head], - channel_groups=ch_groups, - ) - for g in pruner.step(interactive=True): - #print(g) - g.prune() - - # Modify the attention head size and all head size aftering pruning - for m in model.modules(): - if isinstance(m, timm.models.vision_transformer.Attention): - m.head_dim = m.qkv.out_features // (3 * m.num_heads) - - print(model) - test_output = model(example_inputs) - pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs) - print("Base MACs: %d, Pruned MACs: %d"%(base_macs, pruned_macs)) - print("Base Params: %d, Pruned Params: %d"%(base_params, pruned_params)) - - if prunable: - prunable_list.append(model_name) - else: - unprunable_list.append(model_name) - - print("Prunable: %d models, \n %s\n"%(len(prunable_list), prunable_list)) - print("Unprunable: %d models, \n %s\n"%(len(unprunable_list), unprunable_list)) \ No newline at end of file diff --git a/examples/torchvision_models/readme.md b/examples/torchvision_models/readme.md index d198cd95..6d0231a7 100644 --- a/examples/torchvision_models/readme.md +++ b/examples/torchvision_models/readme.md @@ -19,14 +19,14 @@ python torchvision_pruning.py #### Outputs: ``` -Successful Pruning: 81 Models - ['ssdlite320_mobilenet_v3_large', 'ssd300_vgg16', 'fasterrcnn_resnet50_fpn', 'fasterrcnn_resnet50_fpn_v2', 'fasterrcnn_mobilenet_v3_large_320_fpn', 'fasterrcnn_mobilenet_v3_large_fpn', 'fcos_resnet50_fpn', 'keypointrcnn_resnet50_fpn', 'maskrcnn_resnet50_fpn_v2', 'retinanet_resnet50_fpn_v2', 'alexnet', 'vit_b_16', 'vit_b_32', 'vit_l_16', 'vit_l_32', 'vit_h_14', 'convnext_tiny', 'convnext_small', 'convnext_base', 'convnext_large', 'densenet121', 'densenet169', 'densenet201', 'densenet161', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_v2_s', 'efficientnet_v2_m', 'efficientnet_v2_l', 'googlenet', 'inception_v3', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small', 'regnet_y_400mf', 'regnet_y_800mf', 'regnet_y_1_6gf', 'regnet_y_3_2gf', 'regnet_y_8gf', 'regnet_y_16gf', 'regnet_y_32gf', 'regnet_y_128gf', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2', 'fcn_resnet50', 
'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101', 'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'] +Successful Pruning: 77 Models + ['ssdlite320_mobilenet_v3_large', 'ssd300_vgg16', 'fasterrcnn_resnet50_fpn', 'fasterrcnn_resnet50_fpn_v2', 'fasterrcnn_mobilenet_v3_large_320_fpn', 'fasterrcnn_mobilenet_v3_large_fpn', 'fcos_resnet50_fpn', 'keypointrcnn_resnet50_fpn', 'maskrcnn_resnet50_fpn_v2', 'retinanet_resnet50_fpn_v2', 'alexnet', 'vit_b_16', 'vit_b_32', 'vit_l_16', 'vit_l_32', 'vit_h_14', 'convnext_tiny', 'convnext_small', 'convnext_base', 'convnext_large', 'densenet121', 'densenet169', 'densenet201', 'densenet161', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_v2_s', 'efficientnet_v2_m', 'efficientnet_v2_l', 'googlenet', 'inception_v3', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small', 'regnet_y_400mf', 'regnet_y_800mf', 'regnet_y_1_6gf', 'regnet_y_3_2gf', 'regnet_y_8gf', 'regnet_y_16gf', 'regnet_y_32gf', 'regnet_y_128gf', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2', 'fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101', 'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn'] ``` ``` -Unsuccessful Pruning: 4 Models - ['raft_large', 'swin_t', 'swin_s', 'swin_b'] +Unsuccessful Pruning: 8 Models + ['raft_large', 'swin_t', 'swin_s', 'swin_b', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'] ``` #### Vision Transfomer Example diff --git a/examples/transformers/test_latency.py b/examples/transformers/measure_latency.py similarity index 100% rename from examples/transformers/test_latency.py rename to examples/transformers/measure_latency.py diff --git a/examples/transformers/prune_hf_bert.py b/examples/transformers/prune_hf_bert.py index 69546850..635fd5cb 100644 --- a/examples/transformers/prune_hf_bert.py +++ b/examples/transformers/prune_hf_bert.py @@ -14,14 +14,14 @@ imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean") base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs) -channel_groups = {} +num_heads = {} # All heads should be pruned simultaneously, so we group channels by head. 
for m in model.modules(): if isinstance(m, BertSelfAttention): - channel_groups[m.query] = m.num_attention_heads - channel_groups[m.key] = m.num_attention_heads - channel_groups[m.value] = m.num_attention_heads + num_heads[m.query] = m.num_attention_heads + num_heads[m.key] = m.num_attention_heads + num_heads[m.value] = m.num_attention_heads pruner = tp.pruner.MetaPruner( model, @@ -30,7 +30,7 @@ importance=imp, # importance criterion for parameter selection iterative_steps=1, # the number of iterations to achieve target sparsity ch_sparsity=0.5, - channel_groups=channel_groups, + num_heads=num_heads, output_transform=lambda out: out.pooler_output.sum(), ignored_layers=[model.pooler], ) diff --git a/examples/transformers/prune_hf_swin.py b/examples/transformers/prune_hf_swin.py index 07b68d2d..45c4eb2d 100644 --- a/examples/transformers/prune_hf_swin.py +++ b/examples/transformers/prune_hf_swin.py @@ -48,15 +48,15 @@ def get_in_channels(self, layer): print(model) imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean") base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs) -channel_groups = {} +num_heads = {} ignored_layers = [model.classifier] # All heads should be pruned simultaneously, so we group channels by head. for m in model.modules(): if isinstance(m, SwinSelfAttention): - channel_groups[m.query] = m.num_attention_heads - channel_groups[m.key] = m.num_attention_heads - channel_groups[m.value] = m.num_attention_heads + num_heads[m.query] = m.num_attention_heads + num_heads[m.key] = m.num_attention_heads + num_heads[m.value] = m.num_attention_heads pruner = tp.pruner.MetaPruner( model, @@ -65,7 +65,7 @@ def get_in_channels(self, layer): importance=imp, # importance criterion for parameter selection iterative_steps=1, # the number of iterations to achieve target sparsity ch_sparsity=0.5, - channel_groups=channel_groups, + num_heads=num_heads, output_transform=lambda out: out.logits.sum(), ignored_layers=ignored_layers, customized_pruners={SwinPatchMerging: SwinPatchMergingPruner()}, diff --git a/examples/transformers/prune_hf_vit.py b/examples/transformers/prune_hf_vit.py index bc79b481..27256239 100644 --- a/examples/transformers/prune_hf_vit.py +++ b/examples/transformers/prune_hf_vit.py @@ -14,7 +14,7 @@ parser = argparse.ArgumentParser(description='ViT Pruning') parser.add_argument('--model_name', default='google/vit-base-patch16-224', type=str, help='model name') -parser.add_argument('--data_path', default='~/Datasets/shared/imagenet/', type=str, help='model name') +parser.add_argument('--data_path', default='data/imagenet', type=str, help='model name') parser.add_argument('--taylor_batchs', default=10, type=int, help='number of batchs for taylor criterion') parser.add_argument('--pruning_ratio', default=0.5, type=float, help='prune ratio') parser.add_argument('--bottleneck', default=False, action='store_true', help='bottleneck or uniform') @@ -96,14 +96,14 @@ def validate_model(model, val_loader, device): print("Accuracy: %.4f, Loss: %.4f"%(acc_ori, loss_ori)) print("Pruning %s..."%args.model_name) -channel_groups = {} +num_heads = {} ignored_layers = [model.classifier] # All heads should be pruned simultaneously, so we group channels by head. 
for m in model.modules(): if isinstance(m, ViTSelfAttention): - channel_groups[m.query] = m.num_attention_heads - channel_groups[m.key] = m.num_attention_heads - channel_groups[m.value] = m.num_attention_heads + num_heads[m.query] = m.num_attention_heads + num_heads[m.key] = m.num_attention_heads + num_heads[m.value] = m.num_attention_heads if args.bottleneck and isinstance(m, ViTSelfOutput): ignored_layers.append(m.dense) @@ -115,7 +115,7 @@ def validate_model(model, val_loader, device): ch_sparsity=args.pruning_ratio, # target sparsity ignored_layers=ignored_layers, output_transform=lambda out: out.logits.sum(), - channel_groups=channel_groups, + num_heads=num_heads, ) if isinstance(imp, tp.importance.TaylorImportance): diff --git a/examples/transformers/prune_timm_vit.py b/examples/transformers/prune_timm_vit.py index f1d8f654..9b46e838 100644 --- a/examples/transformers/prune_timm_vit.py +++ b/examples/transformers/prune_timm_vit.py @@ -16,7 +16,7 @@ def parse_args(): parser = argparse.ArgumentParser(description='Timm ViT Pruning') parser.add_argument('--model_name', default='vit_base_patch16_224', type=str, help='model name') - parser.add_argument('--data_path', default='~/Datasets/shared/imagenet/', type=str, help='model name') + parser.add_argument('--data_path', default='data/imagenet', type=str, help='model name') parser.add_argument('--taylor_batchs', default=10, type=int, help='number of batchs for taylor criterion') parser.add_argument('--pruning_ratio', default=0.5, type=float, help='prune ratio') parser.add_argument('--bottleneck', default=False, action='store_true', help='bottleneck or uniform') @@ -33,6 +33,7 @@ def parse_args(): # Here we re-implement the forward function of timm.models.vision_transformer.Attention # as the original forward function requires the input and output channels to be identical. 
def forward(self, x): + """https://github.com/huggingface/pytorch-image-models/blob/054c763fcaa7d241564439ae05fbe919ed85e614/timm/models/vision_transformer.py#L79""" B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) @@ -50,7 +51,7 @@ def forward(self, x): attn = self.attn_drop(attn) x = attn @ v - x = x.transpose(1, 2).reshape(B, N, -1) + x = x.transpose(1, 2).reshape(B, N, -1) # original implementation: x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -122,12 +123,12 @@ def main(): base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs) print("Pruning %s..."%args.model_name) - ch_groups = {} + num_heads = {} ignored_layers = [model.head] for m in model.modules(): if isinstance(m, timm.models.vision_transformer.Attention): m.forward = forward.__get__(m, timm.models.vision_transformer.Attention) # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module - ch_groups[m.qkv] = m.num_heads * 3 + num_heads[m.qkv] = m.num_heads if args.bottleneck and isinstance(m, timm.models.vision_transformer.Mlp): ignored_layers.append(m.fc2) # only prune the internal layers of FFN & Attention @@ -143,7 +144,7 @@ def main(): importance=imp, # importance criterion for parameter selection ch_sparsity=args.pruning_ratio, # target sparsity ignored_layers=ignored_layers, - channel_groups=ch_groups, + num_heads=num_heads, # number of heads in self attention round_to=16, ) if isinstance(imp, (tp.importance.GroupTaylorImportance, tp.importance.GroupHessianImportance)): diff --git a/examples/transformers/readme.md b/examples/transformers/readme.md index 0be8c156..a059ffce 100644 --- a/examples/transformers/readme.md +++ b/examples/transformers/readme.md @@ -5,7 +5,7 @@ ### Data Please prepare the ImageNet-1K dataset as follows and modify the data root in the script. 
``` -imagenet/ +./data/imagenet/ train/ n01440764/ n01440764_10026.JPEG @@ -90,7 +90,7 @@ wget https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pr * Measure the latency of the pruned models ```bash -python estimate_latency.py --model pretrained/vit_b_16_pruning_taylor_uniform.pth +python measure_latency.py --model pretrained/vit_b_16_pruning_taylor_uniform.pth ``` ## Pruning ViT-ImageNet-1K from [HF Transformers](https://huggingface.co/docs/transformers/index) diff --git a/examples/transformers/scripts/finetune_hf_vit_b_16_l1_uniform.sh b/examples/transformers/scripts/finetune_hf_vit_b_16_l1_uniform.sh index cf746e75..e0ca0127 100644 --- a/examples/transformers/scripts/finetune_hf_vit_b_16_l1_uniform.sh +++ b/examples/transformers/scripts/finetune_hf_vit_b_16_l1_uniform.sh @@ -16,6 +16,6 @@ torchrun --nproc_per_node=8 finetune.py \ --clip-grad-norm 1 \ --ra-sampler \ --cutmix-alpha 1.0 \ - --data-path "~/Datasets/shared/imagenet/" \ + --data-path "data/imagenet" \ --output-dir output/hf_vit_b_16_pruning_l1_uniform \ --is_huggingface \ \ No newline at end of file diff --git a/examples/transformers/scripts/finetune_hf_vit_b_16_taylor_uniform.sh b/examples/transformers/scripts/finetune_hf_vit_b_16_taylor_uniform.sh index bc192e22..edb2440b 100644 --- a/examples/transformers/scripts/finetune_hf_vit_b_16_taylor_uniform.sh +++ b/examples/transformers/scripts/finetune_hf_vit_b_16_taylor_uniform.sh @@ -16,6 +16,6 @@ torchrun --nproc_per_node=8 finetune.py \ --clip-grad-norm 1 \ --ra-sampler \ --cutmix-alpha 1.0 \ - --data-path "~/Datasets/shared/imagenet/" \ + --data-path "data/imagenet" \ --output-dir output/hf_vit_b_16_pruning_taylor_uniform \ --is_huggingface \ \ No newline at end of file diff --git a/examples/transformers/scripts/finetune_timm_deit_b_16_taylor_uniform.sh b/examples/transformers/scripts/finetune_timm_deit_b_16_taylor_uniform.sh index bf5dd16d..4358f6f7 100644 --- a/examples/transformers/scripts/finetune_timm_deit_b_16_taylor_uniform.sh +++ b/examples/transformers/scripts/finetune_timm_deit_b_16_taylor_uniform.sh @@ -17,6 +17,6 @@ torchrun --nproc_per_node=8 finetune.py \ --ra-sampler \ --random-erase 0.25 \ --cutmix-alpha 1.0 \ - --data-path "~/Datasets/shared/imagenet/" \ + --data-path "data/imagenet" \ --output-dir output/deit_b_16_pruning_taylor_uniform \ --use_imagenet_mean_std \ \ No newline at end of file diff --git a/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh b/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh index 779f14bb..50012b2c 100644 --- a/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh +++ b/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh @@ -16,5 +16,5 @@ torchrun --nproc_per_node=8 finetune.py \ --clip-grad-norm 1 \ --ra-sampler \ --cutmix-alpha 1.0 \ - --data-path "~/Datasets/shared/imagenet/" \ + --data-path "data/imagenet" \ --output-dir output/vit_b_16_pruning_hessian_uniform \ No newline at end of file diff --git a/examples/transformers/scripts/finetune_timm_vit_b_16_l1_uniform.sh b/examples/transformers/scripts/finetune_timm_vit_b_16_l1_uniform.sh index f66a0e87..95f5e671 100644 --- a/examples/transformers/scripts/finetune_timm_vit_b_16_l1_uniform.sh +++ b/examples/transformers/scripts/finetune_timm_vit_b_16_l1_uniform.sh @@ -16,5 +16,5 @@ torchrun --nproc_per_node=8 finetune.py \ --clip-grad-norm 1 \ --ra-sampler \ --cutmix-alpha 1.0 \ - --data-path "~/Datasets/shared/imagenet/" \ + --data-path "data/imagenet" \ --output-dir 
output/vit_b_16_pruning_l1_uniform \ No newline at end of file diff --git a/examples/transformers/scripts/finetune_timm_vit_b_16_l2_uniform.sh b/examples/transformers/scripts/finetune_timm_vit_b_16_l2_uniform.sh index 1e764fa2..e233fe7e 100644 --- a/examples/transformers/scripts/finetune_timm_vit_b_16_l2_uniform.sh +++ b/examples/transformers/scripts/finetune_timm_vit_b_16_l2_uniform.sh @@ -16,5 +16,5 @@ torchrun --nproc_per_node=8 finetune.py \ --clip-grad-norm 1 \ --ra-sampler \ --cutmix-alpha 1.0 \ - --data-path "~/Datasets/shared/imagenet/" \ + --data-path "data/imagenet" \ --output-dir output/vit_b_16_pruning_l2_uniform \ No newline at end of file diff --git a/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_bottleneck.sh b/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_bottleneck.sh index 0cfb7555..737bdd13 100644 --- a/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_bottleneck.sh +++ b/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_bottleneck.sh @@ -16,5 +16,5 @@ torchrun --nproc_per_node=8 finetune.py \ --clip-grad-norm 1 \ --ra-sampler \ --cutmix-alpha 1.0 \ - --data-path "~/Datasets/shared/imagenet/" \ + --data-path "data/imagenet" \ --output-dir output/vit_b_16_pruning_taylor_bottleneck \ No newline at end of file diff --git a/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_uniform.sh b/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_uniform.sh index 888085fd..a0fa0675 100644 --- a/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_uniform.sh +++ b/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_uniform.sh @@ -17,5 +17,5 @@ torchrun --nproc_per_node=8 finetune.py \ --ra-sampler \ --random-erase 0.25 \ --cutmix-alpha 1.0 \ - --data-path "~/Datasets/shared/imagenet/" \ + --data-path "data/imagenet" \ --output-dir output/vit_b_16_pruning_taylor_uniform_v2 \ No newline at end of file diff --git a/examples/transformers/scripts/prune_hf_vit_b_16_l1_uniform.sh b/examples/transformers/scripts/prune_hf_vit_b_16_l1_uniform.sh index 8a3e3bd7..4656d0e8 100644 --- a/examples/transformers/scripts/prune_hf_vit_b_16_l1_uniform.sh +++ b/examples/transformers/scripts/prune_hf_vit_b_16_l1_uniform.sh @@ -3,7 +3,7 @@ python prune_hf_vit.py \ --pruning_type l1 \ --pruning_ratio 0.5 \ --taylor_batchs 10 \ - --data_path ~/Datasets/shared/imagenet \ + --data_path data/imagenet \ --test_accuracy \ --train_batch_size 64 \ --val_batch_size 64 \ diff --git a/examples/transformers/scripts/prune_hf_vit_b_16_taylor_uniform.sh b/examples/transformers/scripts/prune_hf_vit_b_16_taylor_uniform.sh index 3d606509..6c2ff8e5 100644 --- a/examples/transformers/scripts/prune_hf_vit_b_16_taylor_uniform.sh +++ b/examples/transformers/scripts/prune_hf_vit_b_16_taylor_uniform.sh @@ -3,7 +3,7 @@ python prune_hf_vit.py \ --pruning_type taylor \ --pruning_ratio 0.5 \ --taylor_batchs 10 \ - --data_path ~/Datasets/shared/imagenet \ + --data_path data/imagenet \ --test_accuracy \ --train_batch_size 64 \ --val_batch_size 64 \ diff --git a/examples/transformers/scripts/prune_timm_deit_b_16_taylor_uniform.sh b/examples/transformers/scripts/prune_timm_deit_b_16_taylor_uniform.sh index d58e94d6..0040e5ad 100644 --- a/examples/transformers/scripts/prune_timm_deit_b_16_taylor_uniform.sh +++ b/examples/transformers/scripts/prune_timm_deit_b_16_taylor_uniform.sh @@ -3,7 +3,7 @@ python prune_timm_vit.py \ --pruning_type taylor \ --pruning_ratio 0.54 \ --taylor_batchs 50 \ - --data_path ~/Datasets/shared/imagenet \ + --data_path data/imagenet \ 
--train_batch_size 64 \ --val_batch_size 64 \ --save_as output/pruned/deit_base_patch16_224_pruned_taylor_uniform.pth \ diff --git a/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh b/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh index 02d9b29f..39e35671 100644 --- a/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh +++ b/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh @@ -4,7 +4,7 @@ python prune_timm_vit.py \ --pruning_ratio 0.5 \ --taylor_batchs 10 \ --test_accuracy \ - --data_path ~/Datasets/shared/imagenet \ + --data_path data/imagenet \ --train_batch_size 64 \ --val_batch_size 64 \ --save_as output/pruned/vit_base_patch16_224_pruned_hessian_uniform.pth \ \ No newline at end of file diff --git a/examples/transformers/scripts/prune_timm_vit_b_16_l1_uniform.sh b/examples/transformers/scripts/prune_timm_vit_b_16_l1_uniform.sh index 7cdb17bc..4df87fd5 100644 --- a/examples/transformers/scripts/prune_timm_vit_b_16_l1_uniform.sh +++ b/examples/transformers/scripts/prune_timm_vit_b_16_l1_uniform.sh @@ -3,7 +3,7 @@ python prune_timm_vit.py \ --pruning_type l1 \ --pruning_ratio 0.5 \ --taylor_batchs 10 \ - --data_path ~/Datasets/shared/imagenet \ + --data_path data/imagenet \ --test_accuracy \ --train_batch_size 64 \ --val_batch_size 64 \ diff --git a/examples/transformers/scripts/prune_timm_vit_b_16_l2_uniform.sh b/examples/transformers/scripts/prune_timm_vit_b_16_l2_uniform.sh index cdd51e41..21386e8b 100644 --- a/examples/transformers/scripts/prune_timm_vit_b_16_l2_uniform.sh +++ b/examples/transformers/scripts/prune_timm_vit_b_16_l2_uniform.sh @@ -3,7 +3,7 @@ python prune_timm_vit.py \ --pruning_type l2 \ --pruning_ratio 0.5 \ --taylor_batchs 10 \ - --data_path ~/Datasets/shared/imagenet \ + --data_path data/imagenet \ --test_accuracy \ --train_batch_size 64 \ --val_batch_size 64 \ diff --git a/examples/transformers/scripts/prune_timm_vit_b_16_taylor_bottleneck.sh b/examples/transformers/scripts/prune_timm_vit_b_16_taylor_bottleneck.sh index 149238c2..4ae381d1 100644 --- a/examples/transformers/scripts/prune_timm_vit_b_16_taylor_bottleneck.sh +++ b/examples/transformers/scripts/prune_timm_vit_b_16_taylor_bottleneck.sh @@ -3,7 +3,7 @@ python prune_timm_vit.py \ --pruning_type taylor \ --pruning_ratio 0.73 \ --taylor_batchs 10 \ - --data_path ~/Datasets/shared/imagenet \ + --data_path data/imagenet \ --bottleneck \ --train_batch_size 64 \ --val_batch_size 64 \ diff --git a/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform.sh b/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform.sh index 2763a364..19589ce0 100644 --- a/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform.sh +++ b/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform.sh @@ -3,7 +3,7 @@ python prune_timm_vit.py \ --pruning_type taylor \ --pruning_ratio 0.54 \ --taylor_batchs 50 \ - --data_path ~/Datasets/shared/imagenet \ + --data_path data/imagenet \ --train_batch_size 64 \ --val_batch_size 64 \ --save_as output/pruned/vit_base_patch16_224_pruned_taylor_uniform.pth \ \ No newline at end of file diff --git a/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform_global.sh b/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform_global.sh index 9173b20b..24672b68 100644 --- a/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform_global.sh +++ b/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform_global.sh @@ -3,7 +3,7 @@ python prune_timm_vit.py \ 
--pruning_type taylor \ --pruning_ratio 0.6 \ --taylor_batchs 10 \ - --data_path ~/Datasets/shared/imagenet \ + --data_path data/imagenet \ --train_batch_size 64 \ --val_batch_size 64 \ --save_as output/pruned/vit_base_patch16_224_pruned_taylor_uniform.pth \ diff --git a/examples/transformers/scripts/test_pretrained_hf_vit_b_16.sh b/examples/transformers/scripts/test_pretrained_hf_vit_b_16.sh index 353b0d06..56f8838f 100644 --- a/examples/transformers/scripts/test_pretrained_hf_vit_b_16.sh +++ b/examples/transformers/scripts/test_pretrained_hf_vit_b_16.sh @@ -16,7 +16,7 @@ python finetune.py \ --clip-grad-norm 1 \ --ra-sampler \ --cutmix-alpha 1.0 \ - --data-path "~/Datasets/shared/imagenet/" \ + --data-path "data/imagenet" \ --test-only \ --interpolation bilinear \ --is_huggingface \ \ No newline at end of file diff --git a/examples/transformers/scripts/test_pretrained_timm_deit_b_16.sh b/examples/transformers/scripts/test_pretrained_timm_deit_b_16.sh index 9d6e3567..7fe0da2d 100644 --- a/examples/transformers/scripts/test_pretrained_timm_deit_b_16.sh +++ b/examples/transformers/scripts/test_pretrained_timm_deit_b_16.sh @@ -16,6 +16,6 @@ python finetune.py \ --clip-grad-norm 1 \ --ra-sampler \ --cutmix-alpha 1.0 \ - --data-path "~/Datasets/shared/imagenet/" \ + --data-path "data/imagenet" \ --test-only \ --use_imagenet_mean_std \ diff --git a/examples/transformers/scripts/test_pretrained_timm_vit_b_16.sh b/examples/transformers/scripts/test_pretrained_timm_vit_b_16.sh index 08b021ec..62f4b7dd 100644 --- a/examples/transformers/scripts/test_pretrained_timm_vit_b_16.sh +++ b/examples/transformers/scripts/test_pretrained_timm_vit_b_16.sh @@ -16,5 +16,5 @@ python finetune.py \ --clip-grad-norm 1 \ --ra-sampler \ --cutmix-alpha 1.0 \ - --data-path "~/Datasets/shared/imagenet/" \ + --data-path "data/imagenet" \ --test-only \ diff --git a/tests/test_concat_split.py b/tests/test_concat_split.py index e7811aa7..b59de7c8 100644 --- a/tests/test_concat_split.py +++ b/tests/test_concat_split.py @@ -60,13 +60,13 @@ def test_pruner(): ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} ignored_layers=ignored_layers, ) - print(model) + base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) for i in range(iterative_steps): for g in pruner.step(interactive=True): - print(g.details()) + #print(g.details()) g.prune() - print(model) + print(model) macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) print(model(example_inputs).shape) print( diff --git a/torch_pruning/dependency.py b/torch_pruning/dependency.py index bc791fa4..2caa2fe3 100644 --- a/torch_pruning/dependency.py +++ b/torch_pruning/dependency.py @@ -506,7 +506,6 @@ def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, o layer_channels = pruner.get_out_channels(m) group = self.get_pruning_group( m, pruner.prune_out_channels, list(range(layer_channels))) - prunable_group = True for dep, _ in group: module = dep.target.module @@ -852,8 +851,10 @@ def _init_shape_information(self): if node.type == ops.OPTYPE.SPLIT: grad_fn = node.grad_fn - if hasattr(grad_fn, '_saved_self_sizes'): + + if hasattr(grad_fn, '_saved_self_sizes') or hasattr(grad_fn, '_saved_split_sizes'): MAX_LEGAL_DIM = 100 + if hasattr(grad_fn, '_saved_split_sizes') and hasattr(grad_fn, '_saved_dim') : if grad_fn._saved_dim != 1 and grad_fn._saved_dim < MAX_LEGAL_DIM: # a temp fix for pytorch==1.11, where the _saved_dim is an uninitialized value like 118745347895359 
continue @@ -1036,6 +1037,7 @@ def _update_split_index_mapping(self, split_node: Node): return offsets = split_node.module.offsets + if offsets is None: return addressed_dep = [] diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py index 692f135d..18abdb5c 100644 --- a/torch_pruning/pruner/algorithms/metapruner.py +++ b/torch_pruning/pruner/algorithms/metapruner.py @@ -27,11 +27,17 @@ class MetaPruner: * round_to (int): round channels to the nearest multiple of round_to. E.g., round_to=8 means channels will be rounded to 8x. Default: None. # Adavanced + * in_channel_groups (Dict[nn.Module, int]): The number of channel groups for layer input. Default: dict(). + * out_channel_groups (Dict[nn.Module, int]): The number of channel groups for layer output. Default: dict(). + * num_heads (Dict[nn.Module, int]): The number of heads for multi-head attention. Default: dict(). * customized_pruners (dict): a dict containing module-pruner pairs. Default: None. * unwrapped_parameters (dict): a dict containing unwrapped parameters & pruning dims. Default: None. * root_module_types (list): types of prunable modules. Default: [nn.Conv2d, nn.Linear, nn.LSTM]. * forward_fn (Callable): A function to execute model.forward. Default: None. * output_transform (Callable): A function to transform network outputs. Default: None. + + # Deprecated + * channel_groups (Dict[nn.Module, int]): output channel grouping. Default: dict(). """ def __init__( @@ -52,6 +58,7 @@ def __init__( # Advanced in_channel_groups: typing.Dict[nn.Module, int] = dict(), # The number of channel groups for layer input out_channel_groups: typing.Dict[nn.Module, int] = dict(), # The number of channel groups for layer output + num_heads: typing.Dict[nn.Module, int] = dict(), # The number of heads for multi-head attention customized_pruners: typing.Dict[typing.Any, function.BasePruningFunc] = None, # pruners for customized layers. E.g., {nn.Linear: my_linear_pruner} unwrapped_parameters: typing.Dict[nn.Parameter, int] = None, # unwrapped nn.Parameters & pruning_dims. For example, {ViT.pos_emb: 0} root_module_types: typing.List = [ops.TORCH_CONV, ops.TORCH_LINEAR, ops.TORCH_LSTM], # root module for each group @@ -59,7 +66,7 @@ def __init__( output_transform: typing.Callable = None, # a function to transform network outputs # deprecated - channel_groups: typing.Dict[nn.Module, int] = dict(), # channel groups for layers + channel_groups: typing.Dict[nn.Module, int] = dict(), # channel grouping ): self.model = model self.importance = importance @@ -72,9 +79,12 @@ def __init__( warnings.warn("channel_groups is deprecated. 
Please use in_channel_groups and out_channel_groups instead.") out_channel_groups.update(channel_groups) + if len(num_heads) > 0: + out_channel_groups.update(num_heads) + self.in_channel_groups = in_channel_groups self.out_channel_groups = out_channel_groups - + self.num_heads = num_heads self.root_module_types = root_module_types self.round_to = round_to @@ -129,7 +139,7 @@ def __init__( in_ch_group = layer_pruner.get_in_channel_groups(m) out_ch_group = layer_pruner.get_out_channel_groups(m) if isinstance(m, ops.TORCH_CONV) and m.groups == m.out_channels: - continue + continue if in_ch_group > 1: self.in_channel_groups[m] = in_ch_group if out_ch_group > 1: @@ -149,7 +159,7 @@ def __init__( if self.global_pruning: initial_total_channels = 0 for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types): - group = self._downstream_node_as_root_if_unbind(group) + group = self._downstream_node_as_root_if_attention(group) initial_total_channels += ( (self.DG.get_out_channels(group[0][0].target.module) ) // self._get_channel_groups(group) ) self.initial_total_channels = initial_total_channels @@ -209,8 +219,8 @@ def _check_sparsity(self, group) -> bool: def _get_channel_groups(self, group) -> int: ch_groups = 1 - has_unbind = False - unbind_node = None + #has_unbind = False + #unbind_node = None for dep, _ in group: module = dep.target.module @@ -220,24 +230,24 @@ def _get_channel_groups(self, group) -> int: if module in channel_groups: ch_groups = channel_groups[module] - if dep.source.type==ops.OPTYPE.UNBIND: - has_unbind = True - unbind_node = dep.source + #if dep.source.type==ops.OPTYPE.UNBIND: + # has_unbind = True + # unbind_node = dep.source - if has_unbind and ch_groups>1: - ch_groups = ch_groups // len(unbind_node.outputs) + #if has_unbind and ch_groups>1: + # ch_groups = ch_groups // len(unbind_node.outputs) return ch_groups # no channel grouping - def _downstream_node_as_root_if_unbind(self, group): + def _downstream_node_as_root_if_attention(self, group): # Use a downstream node as the root if torch.unbind exists. 
TODO: find a general way to handle torch.unbind in timm - qkv_unbind = False + is_attention = False downstream_dep = None for _dep, _idxs in group: - if _dep.source.type == ops.OPTYPE.UNBIND: - qkv_unbind = True + if _dep.source.module in self.num_heads: + is_attention = True if isinstance(_dep.target.module, tuple(self.root_module_types)): downstream_dep = _dep - if qkv_unbind and downstream_dep is not None: # use a downstream node as the root node + if is_attention and downstream_dep is not None: # use a downstream node as the root node for attention layers group = self.DG.get_pruning_group(downstream_dep.target.module, downstream_dep.handler, _idxs) return group @@ -254,7 +264,7 @@ def prune_local(self) -> typing.Generator: for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types): if self._check_sparsity(group): # check pruning ratio - group = self._downstream_node_as_root_if_unbind(group) + group = self._downstream_node_as_root_if_attention(group) module = group[0][0].target.module pruning_fn = group[0][0].handler @@ -302,7 +312,6 @@ def prune_local(self) -> typing.Generator: group = self.DG.get_pruning_group( module, pruning_fn, pruning_idxs.tolist()) - if self.DG.check_pruning_group(group): yield group @@ -315,7 +324,7 @@ def prune_global(self) -> typing.Generator: global_importance = [] for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types): if self._check_sparsity(group): - group = self._downstream_node_as_root_if_unbind(group) + group = self._downstream_node_as_root_if_attention(group) ch_groups = self._get_channel_groups(group) imp = self.estimate_importance(group, ch_groups=ch_groups) if imp is None: continue diff --git a/torch_pruning/pruner/function.py b/torch_pruning/pruner/function.py index 394791c8..eafa4efb 100644 --- a/torch_pruning/pruner/function.py +++ b/torch_pruning/pruner/function.py @@ -246,7 +246,8 @@ def prune_out_channels(self, layer: nn.Module, idxs: Sequence[int]) -> nn.Module keep_idxs.sort() if layer.elementwise_affine: layer.weight = self._prune_parameter_and_grad(layer.weight, keep_idxs, pruning_dim) - layer.bias = self._prune_parameter_and_grad(layer.bias, keep_idxs, pruning_dim) + if layer.bias is not None: + layer.bias = self._prune_parameter_and_grad(layer.bias, keep_idxs, pruning_dim) if pruning_dim != -1: layer.normalized_shape = layer.normalized_shape[:pruning_dim] + ( keep_idxs.size(0), ) + layer.normalized_shape[pruning_dim+1:]
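For reference, the `num_heads` argument introduced above replaces the deprecated `channel_groups` dict and is wired up as in the ViT examples of this patch. A minimal sketch assembled from those examples (model name, sparsity, and importance criterion are only illustrative; as in prune_timm_vit.py, Attention.forward should also be re-implemented before running the pruned model):

    import torch
    import timm
    import torch_pruning as tp

    model = timm.create_model('vit_base_patch16_224', pretrained=False).eval()
    example_inputs = torch.randn(1, 3, 224, 224)

    # Group the qkv output channels by attention head so that whole heads are pruned together.
    num_heads = {}
    for m in model.modules():
        if isinstance(m, timm.models.vision_transformer.Attention):
            num_heads[m.qkv] = m.num_heads

    pruner = tp.pruner.MetaPruner(
        model,
        example_inputs,
        importance=tp.importance.MagnitudeImportance(p=2, group_reduction="mean"),
        ch_sparsity=0.5,
        ignored_layers=[model.head],
        num_heads=num_heads,  # new argument; the old channel_groups dict is still accepted but deprecated
    )
    pruner.step()

    # Keep the attention metadata consistent after pruning, as done in prune_timm_vit.py.
    for m in model.modules():
        if isinstance(m, timm.models.vision_transformer.Attention):
            m.head_dim = m.qkv.out_features // (3 * m.num_heads)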