From 45bcf3324c2f61b7aecca47f4dd9647ea7a4fcbc Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Mon, 2 Oct 2023 16:27:13 +0800
Subject: [PATCH 01/17] Fixed a bug: pruning with layernorm.bias=None
---
torch_pruning/pruner/function.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/torch_pruning/pruner/function.py b/torch_pruning/pruner/function.py
index 394791c8..eafa4efb 100644
--- a/torch_pruning/pruner/function.py
+++ b/torch_pruning/pruner/function.py
@@ -246,7 +246,8 @@ def prune_out_channels(self, layer: nn.Module, idxs: Sequence[int]) -> nn.Module
keep_idxs.sort()
if layer.elementwise_affine:
layer.weight = self._prune_parameter_and_grad(layer.weight, keep_idxs, pruning_dim)
- layer.bias = self._prune_parameter_and_grad(layer.bias, keep_idxs, pruning_dim)
+ if layer.bias is not None:
+ layer.bias = self._prune_parameter_and_grad(layer.bias, keep_idxs, pruning_dim)
if pruning_dim != -1:
layer.normalized_shape = layer.normalized_shape[:pruning_dim] + (
keep_idxs.size(0), ) + layer.normalized_shape[pruning_dim+1:]
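A minimal sketch (not part of this patch) of the case the fix covers, assuming PyTorch >= 2.1 where ``nn.LayerNorm`` accepts ``bias=False`` and thus keeps ``elementwise_affine=True`` while ``layer.bias`` is ``None``:

```python
# Hedged sketch: pruning a model whose LayerNorm has a weight but no bias.
# Assumes PyTorch >= 2.1 (nn.LayerNorm(..., bias=False)) and torch-pruning imported as tp.
import torch
import torch.nn as nn
import torch_pruning as tp

model = nn.Sequential(
    nn.Linear(8, 16),
    nn.LayerNorm(16, bias=False),  # elementwise_affine=True, but layer.bias is None
    nn.Linear(16, 4),
)
example_inputs = torch.randn(1, 8)

pruner = tp.pruner.MagnitudePruner(
    model, example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),
    ch_sparsity=0.5,
    ignored_layers=[model[2]],  # keep the output layer intact
)
pruner.step()  # previously failed on layer.bias; the bias is now skipped when it is None
print(model)
```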
From 8f7658fea28b38f9efcecda063964c70fce8f45d Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Mon, 2 Oct 2023 16:27:51 +0800
Subject: [PATCH 02/17] Update timm examples
---
examples/timm_models/prune_timm_models.py | 66 +++++++++++++
examples/timm_models/readme.md | 49 +++++++---
examples/timm_models/timm_vit.py | 107 ----------------------
3 files changed, 101 insertions(+), 121 deletions(-)
create mode 100644 examples/timm_models/prune_timm_models.py
delete mode 100644 examples/timm_models/timm_vit.py
diff --git a/examples/timm_models/prune_timm_models.py b/examples/timm_models/prune_timm_models.py
new file mode 100644
index 00000000..6c706edf
--- /dev/null
+++ b/examples/timm_models/prune_timm_models.py
@@ -0,0 +1,66 @@
+import os, sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Sequence
+import timm
+from timm.models.vision_transformer import Attention
+import torch_pruning as tp
+import argparse
+
+parser = argparse.ArgumentParser(description='Prune timm models')
+parser.add_argument('--model', default=None, type=str, help='model name')
+parser.add_argument('--ch_sparsity', default=0.5, type=float, help='channel sparsity')
+parser.add_argument('--global_pruning', default=False, action='store_true', help='global pruning')
+parser.add_argument('--pretrained', default=False, action='store_true', help='use pretrained weights')
+parser.add_argument('--list_models', default=False, action='store_true', help='list all models in timm')
+args = parser.parse_args()
+
+def main():
+ timm_models = timm.list_models()
+ if args.list_models:
+ print(timm_models)
+ if args.model is None:
+ return
+ assert args.model in timm_models, "Model %s is not in timm model list: %s"%(args.model, timm_models)
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ model = timm.create_model(args.model, pretrained=args.pretrained, no_jit=True).eval().to(device)
+
+
+ imp = tp.importance.GroupNormImportance()
+ print("Pruning %s..."%args.model)
+
+ input_size = model.default_cfg['input_size']
+ example_inputs = torch.randn(1, *input_size).to(device)
+ test_output = model(example_inputs)
+ ignored_layers = []
+ for m in model.modules():
+ if isinstance(m, nn.Linear) and m.out_features == model.num_classes:
+ ignored_layers.append(m)
+ print("Ignore classifier layer: ", m)
+
+ print("========Before pruning========")
+ print(model)
+ base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
+ pruner = tp.pruner.MagnitudePruner(
+ model,
+ example_inputs,
+ global_pruning=args.global_pruning, # If False, a uniform sparsity will be assigned to different layers.
+ importance=imp, # importance criterion for parameter selection
+ iterative_steps=1, # the number of iterations to achieve target sparsity
+ ch_sparsity=args.ch_sparsity, # target sparsity
+ ignored_layers=ignored_layers,
+ )
+ for g in pruner.step(interactive=True):
+ g.prune()
+
+ print("========After pruning========")
+ print(model)
+ test_output = model(example_inputs)
+ pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
+ print("MACs: %.4f G => %.4f G"%(base_macs/1e9, pruned_macs/1e9))
+ print("Params: %.4f M => %.4f M"%(base_params/1e6, pruned_params/1e6))
+
+if __name__=='__main__':
+ main()
\ No newline at end of file
diff --git a/examples/timm_models/readme.md b/examples/timm_models/readme.md
index 3b262412..2f146ab5 100644
--- a/examples/timm_models/readme.md
+++ b/examples/timm_models/readme.md
@@ -1,29 +1,50 @@
# Pruning Models from Timm
-## 0. Requirements
+
+## 0. List all models in Timm
```bash
-pip install -r requirements.txt
+python prune_timm_models.py --list_models
```
-Tested environment:
+
+Output:
```
-pytorch==1.12.1
-timm=0.9.2
+['bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_384', 'beit_large_patch16_224', 'beit_large_patch16_384', 'beit_large_patch16_512', 'beitv2_base_patch16_224', ...]
```
## 1. Pruning
-
-```python
-python timm_pruning.py
+```bash
+python prune_timm_models.py --model convnext_xxlarge --ch_sparsity 0.5 # --global_pruning
```
#### Outputs:
-Prunable: 119 models,
-```
- ['beit_base_patch16_224', 'beit_base_patch16_384', 'beit_large_patch16_224', 'beit_large_patch16_384', 'beit_large_patch16_512', 'beitv2_base_patch16_224', 'beitv2_large_patch16_224', 'botnet26t_256', 'botnet50ts_256', 'convmixer_768_32', 'convmixer_1024_20_ks9_p14', 'convmixer_1536_20', 'convnext_atto', 'convnext_atto_ols', 'convnext_base', 'convnext_femto', 'convnext_femto_ols', 'convnext_large', 'convnext_large_mlp', 'convnext_nano', 'convnext_nano_ols', 'convnext_pico', 'convnext_pico_ols', 'convnext_small', 'convnext_tiny', 'convnext_tiny_hnf', 'convnext_xlarge', 'convnext_xxlarge', 'convnextv2_atto', 'convnextv2_base', 'convnextv2_femto', 'convnextv2_huge', 'convnextv2_large', 'convnextv2_nano', 'convnextv2_pico', 'convnextv2_small', 'convnextv2_tiny', 'darknet17', 'darknet21', 'darknet53', 'darknetaa53', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'densenet264d', 'dla34', 'dla46_c', 'dla46x_c', 'dla60', 'dla60x', 'dla60x_c', 'dla102', 'dla102x', 'dla102x2', 'dla169', 'eca_botnext26ts_256', 'eca_resnet33ts', 'eca_resnext26ts', 'eca_vovnet39b', 'ecaresnet26t', 'ecaresnet50d', 'ecaresnet50d_pruned', 'ecaresnet50t', 'ecaresnet101d', 'ecaresnet101d_pruned', 'ecaresnet200d', 'ecaresnet269d', 'ecaresnetlight', 'ecaresnext26t_32x4d', 'ecaresnext50t_32x4d', 'efficientnet_b0', 'efficientnet_b0_g8_gn', 'efficientnet_b0_g16_evos', 'efficientnet_b0_gn', 'efficientnet_b1', 'efficientnet_b1_pruned', 'efficientnet_b2', 'efficientnet_b2_pruned', 'efficientnet_b2a', 'efficientnet_b3', 'efficientnet_b3_gn', 'efficientnet_b3_pruned', 'efficientnet_b3a', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8', 'efficientnet_el', 'efficientnet_el_pruned', 'efficientnet_em', 'efficientnet_es', 'efficientnet_es_pruned', 'efficientnet_l2', 'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4', 'efficientnetv2_l', 'efficientnetv2_m', 'efficientnetv2_rw_m', 'efficientnetv2_rw_s', 'efficientnetv2_rw_t', 'efficientnetv2_s', 'efficientnetv2_xl', 'ese_vovnet19b_dw', 'ese_vovnet19b_slim', 'ese_vovnet19b_slim_dw', 'ese_vovnet39b', 'ese_vovnet57b', 'ese_vovnet99b', 'fbnetc_100', 'fbnetv3_b', 'fbnetv3_d', 'fbnetv3_g', 'gc_efficientnetv2_rw_t', 'gcresnet33ts']
```
+========Before pruning========
+...
+ (norm_pre): Identity()
+ (head): NormMlpClassifierHead(
+ (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Identity())
+ (norm): LayerNorm2d((3072,), eps=1e-05, elementwise_affine=True)
+ (flatten): Flatten(start_dim=1, end_dim=-1)
+ (pre_logits): Identity()
+ (drop): Dropout(p=0.0, inplace=False)
+ (fc): Linear(in_features=3072, out_features=1000, bias=True)
+ )
+)
-Unprunable: 175 models,
-```
- ['bat_resnext26ts', 'caformer_b36', 'caformer_m36', 'caformer_s18', 'caformer_s36', 'cait_m36_384', 'cait_m48_448', 'cait_s24_224', 'cait_s24_384', 'cait_s36_384', 'cait_xs24_384', 'cait_xxs24_224', 'cait_xxs24_384', 'cait_xxs36_224', 'cait_xxs36_384', 'coat_lite_medium', 'coat_lite_medium_384', 'coat_lite_mini', 'coat_lite_small', 'coat_lite_tiny', 'coat_mini', 'coat_small', 'coat_tiny', 'coatnet_0_224', 'coatnet_0_rw_224', 'coatnet_1_224', 'coatnet_1_rw_224', 'coatnet_2_224', 'coatnet_2_rw_224', 'coatnet_3_224', 'coatnet_3_rw_224', 'coatnet_4_224', 'coatnet_5_224', 'coatnet_bn_0_rw_224', 'coatnet_nano_cc_224', 'coatnet_nano_rw_224', 'coatnet_pico_rw_224', 'coatnet_rmlp_0_rw_224', 'coatnet_rmlp_1_rw2_224', 'coatnet_rmlp_1_rw_224', 'coatnet_rmlp_2_rw_224', 'coatnet_rmlp_2_rw_384', 'coatnet_rmlp_3_rw_224', 'coatnet_rmlp_nano_rw_224', 'coatnext_nano_rw_224', 'convformer_b36', 'convformer_m36', 'convformer_s18', 'convformer_s36', 'convit_base', 'convit_small', 'convit_tiny', 'crossvit_9_240', 'crossvit_9_dagger_240', 'crossvit_15_240', 'crossvit_15_dagger_240', 'crossvit_15_dagger_408', 'crossvit_18_240', 'crossvit_18_dagger_240', 'crossvit_18_dagger_408', 'crossvit_base_240', 'crossvit_small_240', 'crossvit_tiny_240', 'cs3darknet_focus_l', 'cs3darknet_focus_m', 'cs3darknet_focus_s', 'cs3darknet_focus_x', 'cs3darknet_l', 'cs3darknet_m', 'cs3darknet_s', 'cs3darknet_x', 'cs3edgenet_x', 'cs3se_edgenet_x', 'cs3sedarknet_l', 'cs3sedarknet_x', 'cs3sedarknet_xdw', 'cspdarknet53', 'cspresnet50', 'cspresnet50d', 'cspresnet50w', 'cspresnext50', 'davit_base', 'davit_giant', 'davit_huge', 'davit_large', 'davit_small', 'davit_tiny', 'deit3_base_patch16_224', 'deit3_base_patch16_384', 'deit3_huge_patch14_224', 'deit3_large_patch16_224', 'deit3_large_patch16_384', 'deit3_medium_patch16_224', 'deit3_small_patch16_224', 'deit3_small_patch16_384', 'deit_base_distilled_patch16_224', 'deit_base_distilled_patch16_384', 'deit_base_patch16_224', 'deit_base_patch16_384', 'deit_small_distilled_patch16_224', 'deit_small_patch16_224', 'deit_tiny_distilled_patch16_224', 'deit_tiny_patch16_224', 'densenetblur121d', 'dla60_res2net', 'dla60_res2next', 'dm_nfnet_f0', 'dm_nfnet_f1', 'dm_nfnet_f2', 'dm_nfnet_f3', 'dm_nfnet_f4', 'dm_nfnet_f5', 'dm_nfnet_f6', 'dpn48b', 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn107', 'dpn131', 'eca_halonext26ts', 'eca_nfnet_l0', 'eca_nfnet_l1', 'eca_nfnet_l2', 'eca_nfnet_l3', 'edgenext_base', 'edgenext_small', 'edgenext_small_rw', 'edgenext_x_small', 'edgenext_xx_small', 'efficientformer_l1', 'efficientformer_l3', 'efficientformer_l7', 'efficientformerv2_l', 'efficientformerv2_s0', 'efficientformerv2_s1', 'efficientformerv2_s2', 'efficientnet_b3_g8_gn', 'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e', 'ese_vovnet39b_evos', 'eva02_base_patch14_224', 'eva02_base_patch14_448', 'eva02_base_patch16_clip_224', 'eva02_enormous_patch14_clip_224', 'eva02_large_patch14_224', 'eva02_large_patch14_448', 'eva02_large_patch14_clip_224', 'eva02_large_patch14_clip_336', 'eva02_small_patch14_224', 'eva02_small_patch14_336', 'eva02_tiny_patch14_224', 'eva02_tiny_patch14_336', 'eva_giant_patch14_224', 'eva_giant_patch14_336', 'eva_giant_patch14_560', 'eva_giant_patch14_clip_224', 'eva_large_patch14_196', 'eva_large_patch14_336', 'flexivit_base', 'flexivit_large', 'flexivit_small', 'focalnet_base_lrf', 'focalnet_base_srf', 'focalnet_huge_fl3', 'focalnet_huge_fl4', 'focalnet_large_fl3', 'focalnet_large_fl4', 'focalnet_small_lrf', 'focalnet_small_srf', 'focalnet_tiny_lrf', 
'focalnet_tiny_srf', 'focalnet_xlarge_fl3', 'focalnet_xlarge_fl4']
+
+========After pruning========
+...
+ (norm_pre): Identity()
+ (head): NormMlpClassifierHead(
+ (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Identity())
+ (norm): LayerNorm2d((1536,), eps=1e-05, elementwise_affine=True)
+ (flatten): Flatten(start_dim=1, end_dim=-1)
+ (pre_logits): Identity()
+ (drop): Dropout(p=0.0, inplace=False)
+ (fc): Linear(in_features=1536, out_features=1000, bias=True)
+ )
+)
+MACs: 197.9920 G => 49.7716 G
+Params: 846.4710 M => 213.2587 M
```
diff --git a/examples/timm_models/timm_vit.py b/examples/timm_models/timm_vit.py
deleted file mode 100644
index 3f0ff4ad..00000000
--- a/examples/timm_models/timm_vit.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import os, sys
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))
-
-import torch
-import torch.nn as nn
-import timm
-import torch_pruning as tp
-from typing import Sequence
-
-from timm.models.vision_transformer import Attention
-import torch.nn.functional as F
-
-def timm_attention_forward(self, x):
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
- q, k, v = qkv.unbind(0)
- q, k = self.q_norm(q), self.k_norm(k)
-
- if self.fused_attn:
- x = F.scaled_dot_product_attention(
- q, k, v,
- dropout_p=self.attn_drop.p,
- )
- else:
- q = q * self.scale
- attn = q @ k.transpose(-2, -1)
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
- x = attn @ v
-
- #x = x.transpose(1, 2).reshape(B, N, C) # this line forces the input and output channels to be identical.
- x = x.transpose(0, 1).reshape(B, N, -1)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-# timm==0.9.2
-# torch==1.12.1
-
-timm_models = timm.list_models()
-example_inputs = torch.randn(1,3,224,224)
-imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")
-prunable_list = []
-unprunable_list = []
-problem_with_input_shape = []
-
-for i, model_name in enumerate(timm_models):
- if not model_name=='vit_base_patch8_224':
- continue
-
- print("Pruning %s..."%model_name)
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- #if 'rexnet' in model_name or 'sequencer' in model_name or 'botnet' in model_name: # pruning process stuck with that architectures - skip them.
- # unprunable_list.append(model_name)
- # continue
- try:
- model = timm.create_model(model_name, pretrained=False, no_jit=True).eval().to(device)
- except: # out of memory error
- model = timm.create_model(model_name, pretrained=False, no_jit=True).eval()
- device = 'cpu'
- ch_groups = {}
- for m in model.modules():
- if isinstance(m, timm.models.vision_transformer.Attention):
- m.forward = timm_attention_forward.__get__(m, Attention) # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
- ch_groups[m.qkv] = m.num_heads * 3
-
- input_size = model.default_cfg['input_size']
- example_inputs = torch.randn(1, *input_size).to(device)
- test_output = model(example_inputs)
-
- print(model)
- prunable = True
- #try:
- if True:
- base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
- pruner = tp.pruner.MagnitudePruner(
- model,
- example_inputs,
- global_pruning=False, # If False, a uniform sparsity will be assigned to different layers.
- importance=imp, # importance criterion for parameter selection
- iterative_steps=1, # the number of iterations to achieve target sparsity
- ch_sparsity=0.5,
- ignored_layers=[model.head],
- channel_groups=ch_groups,
- )
- for g in pruner.step(interactive=True):
- #print(g)
- g.prune()
-
- # Modify the attention head size and all head size aftering pruning
- for m in model.modules():
- if isinstance(m, timm.models.vision_transformer.Attention):
- m.head_dim = m.qkv.out_features // (3 * m.num_heads)
-
- print(model)
- test_output = model(example_inputs)
- pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
- print("Base MACs: %d, Pruned MACs: %d"%(base_macs, pruned_macs))
- print("Base Params: %d, Pruned Params: %d"%(base_params, pruned_params))
-
- if prunable:
- prunable_list.append(model_name)
- else:
- unprunable_list.append(model_name)
-
- print("Prunable: %d models, \n %s\n"%(len(prunable_list), prunable_list))
- print("Unprunable: %d models, \n %s\n"%(len(unprunable_list), unprunable_list))
\ No newline at end of file
From 4ac69963286f4aaec080def5bf41020393a703a6 Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Mon, 2 Oct 2023 16:28:56 +0800
Subject: [PATCH 03/17] Update Readme
---
README.md | 18 ++++++++++++++++++
1 file changed, 18 insertions(+)
diff --git a/README.md b/README.md
index 51d31ccd..01f1be3c 100644
--- a/README.md
+++ b/README.md
@@ -244,7 +244,25 @@ With DepGraph, it is easy to design some "group-level" criteria to estimate the
+#### Modify model attributes
+In some implementations, a model's forward pass may rely on static attributes. For example, in [``convformer_s18``](https://github.com/huggingface/pytorch-image-models/blob/054c763fcaa7d241564439ae05fbe919ed85e614/timm/models/metaformer.py#L107) of timm, we have:
+
+```python
+class Scale(nn.Module):
+ """
+ Scale vector by element multiplications.
+ """
+
+ def __init__(self, dim, init_value=1.0, trainable=True, use_nchw=True):
+ super().__init__()
+ self.shape = (dim, 1, 1) if use_nchw else (dim,)
+ self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)
+
+ def forward(self, x):
+ return x * self.scale.view(self.shape) # => x * self.scale.view(-1, 1, 1), this works for pruning
+```
+where the ```forward``` function relies on ``self.shape`` during the forward pass. However, ``self.shape`` becomes stale after pruning and should be adjusted manually.
### 3. Save & Load
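As a hedged illustration (not part of this patch series), such attributes could be re-synchronized after pruning roughly as follows; the ``Scale`` import path is an assumption based on the link above:

```python
# Hypothetical post-pruning fix-up for timm's Scale modules (sketch, not the library API).
import torch.nn as nn
from timm.models.metaformer import Scale  # assumed import path, see the linked file above

def sync_scale_shapes(model: nn.Module):
    for m in model.modules():
        if isinstance(m, Scale):
            dim = m.scale.numel()  # self.scale may have been pruned by torch-pruning
            m.shape = (dim, 1, 1) if len(m.shape) == 3 else (dim,)  # keep self.shape in sync
```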
From 8983df50eb0b0b28c7c56fbf9edb64956eb21f4f Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Mon, 2 Oct 2023 16:29:58 +0800
Subject: [PATCH 04/17] Update Readme
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 01f1be3c..92f94316 100644
--- a/README.md
+++ b/README.md
@@ -262,7 +262,7 @@ class Scale(nn.Module):
def forward(self, x):
return x * self.scale.view(self.shape) # => x * self.scale.view(-1, 1, 1), this works for pruning
```
-where the ```forward``` function relies on ``self.shape`` during the forward pass. However, ``self.shape`` becomes stale after pruning and should be adjusted manually.
+where the ```forward``` function relies on ``self.shape`` during the forward pass. However, the true shape of ``self.scale`` changes after pruning, so ``self.shape`` should be adjusted manually.
### 3. Save & Load
From f5acec3edba480cab3a1a8bbc9c2fef80f528aa3 Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Mon, 2 Oct 2023 16:31:00 +0800
Subject: [PATCH 05/17] Update Readme
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 92f94316..975c756d 100644
--- a/README.md
+++ b/README.md
@@ -244,7 +244,7 @@ With DepGraph, it is easy to design some "group-level" criteria to estimate the
-#### Modify model attributes
+#### Modify attributes or forward functions
In some implementations, a model's forward pass may rely on static attributes. For example, in [``convformer_s18``](https://github.com/huggingface/pytorch-image-models/blob/054c763fcaa7d241564439ae05fbe919ed85e614/timm/models/metaformer.py#L107) of timm, we have:
From 81ece2b87e8f3fa2061c7ca03e3d9eb3d63485de Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Mon, 2 Oct 2023 16:40:59 +0800
Subject: [PATCH 06/17] Remove old scripts
---
examples/timm_models/timm_beit.py | 123 ---------------------------
examples/timm_models/timm_pruning.py | 64 --------------
2 files changed, 187 deletions(-)
delete mode 100644 examples/timm_models/timm_beit.py
delete mode 100644 examples/timm_models/timm_pruning.py
diff --git a/examples/timm_models/timm_beit.py b/examples/timm_models/timm_beit.py
deleted file mode 100644
index b6d063c6..00000000
--- a/examples/timm_models/timm_beit.py
+++ /dev/null
@@ -1,123 +0,0 @@
-import os, sys
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))
-
-import torch
-import torch.nn as nn
-import timm
-import torch_pruning as tp
-from typing import Sequence
-
-from timm.models.vision_transformer import Attention
-import torch.nn.functional as F
-
-def forward(self, x, shared_rel_pos_bias = None):
- B, N, C = x.shape
-
- qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
- qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
- qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
- q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
-
- if self.fused_attn:
- rel_pos_bias = None
- if self.relative_position_bias_table is not None:
- rel_pos_bias = self._get_rel_pos_bias()
- if shared_rel_pos_bias is not None:
- rel_pos_bias = rel_pos_bias + shared_rel_pos_bias
- elif shared_rel_pos_bias is not None:
- rel_pos_bias = shared_rel_pos_bias
-
- x = F.scaled_dot_product_attention(
- q, k, v,
- attn_mask=rel_pos_bias,
- dropout_p=self.attn_drop.p,
- )
- else:
- q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
-
- if self.relative_position_bias_table is not None:
- attn = attn + self._get_rel_pos_bias()
- if shared_rel_pos_bias is not None:
- attn = attn + shared_rel_pos_bias
-
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
- x = attn @ v
- x = x.transpose(1, 2).reshape(B, N, -1)
- #x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-# timm==0.9.2
-# torch==1.12.1
-
-timm_models = timm.list_models()
-example_inputs = torch.randn(1,3,224,224)
-imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")
-prunable_list = []
-unprunable_list = []
-problem_with_input_shape = []
-
-for i, model_name in enumerate(timm_models):
- if not model_name=='beit_base_patch16_224':
- continue
-
- print("Pruning %s..."%model_name)
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- #if 'rexnet' in model_name or 'sequencer' in model_name or 'botnet' in model_name: # pruning process stuck with that architectures - skip them.
- # unprunable_list.append(model_name)
- # continue
- try:
- model = timm.create_model(model_name, pretrained=False, no_jit=True).eval().to(device)
- except: # out of memory error
- model = timm.create_model(model_name, pretrained=False, no_jit=True).eval()
- device = 'cpu'
- ch_groups = {}
- for m in model.modules():
- if isinstance(m, timm.models.vision_transformer.Attention):
- m.forward = timm_attention_forward.__get__(m, Attention) # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
- ch_groups[m.qkv] = m.num_heads * 3
-
- input_size = model.default_cfg['input_size']
- example_inputs = torch.randn(1, *input_size).to(device)
- test_output = model(example_inputs)
-
- print(model)
- prunable = True
- #try:
- if True:
- base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
- pruner = tp.pruner.MagnitudePruner(
- model,
- example_inputs,
- global_pruning=False, # If False, a uniform sparsity will be assigned to different layers.
- importance=imp, # importance criterion for parameter selection
- iterative_steps=1, # the number of iterations to achieve target sparsity
- ch_sparsity=0.5,
- ignored_layers=[model.head],
- channel_groups=ch_groups,
- )
- for g in pruner.step(interactive=True):
- #print(g)
- g.prune()
-
- # Modify the attention head size and all head size aftering pruning
- for m in model.modules():
- if isinstance(m, timm.models.vision_transformer.Attention):
- m.head_dim = m.qkv.out_features // (3 * m.num_heads)
-
- print(model)
- test_output = model(example_inputs)
- pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
- print("Base MACs: %d, Pruned MACs: %d"%(base_macs, pruned_macs))
- print("Base Params: %d, Pruned Params: %d"%(base_params, pruned_params))
-
- if prunable:
- prunable_list.append(model_name)
- else:
- unprunable_list.append(model_name)
-
- print("Prunable: %d models, \n %s\n"%(len(prunable_list), prunable_list))
- print("Unprunable: %d models, \n %s\n"%(len(unprunable_list), unprunable_list))
\ No newline at end of file
diff --git a/examples/timm_models/timm_pruning.py b/examples/timm_models/timm_pruning.py
deleted file mode 100644
index f785b392..00000000
--- a/examples/timm_models/timm_pruning.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import os, sys
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))
-
-import torch
-import timm
-import torch_pruning as tp
-
-# timm==0.9.2
-# torch==1.12.1
-
-timm_models = timm.list_models()
-print(timm_models)
-example_inputs = torch.randn(1,3,224,224)
-imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")
-prunable_list = []
-unprunable_list = []
-problem_with_input_shape = []
-for i, model_name in enumerate(timm_models):
- print("Pruning %s..."%model_name)
- if "botnet" in model_name or "coatnet" in model_name or "coatnext" in model_name:
- unprunable_list.append(model_name)
- continue
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- #if 'rexnet' in model_name or 'sequencer' in model_name or 'botnet' in model_name: # pruning process stuck with that architectures - skip them.
- # unprunable_list.append(model_name)
- # continue
- try:
- model = timm.create_model(model_name, pretrained=False, no_jit=True).eval().to(device)
- except: # out of memory error
- model = timm.create_model(model_name, pretrained=False, no_jit=True).eval()
- device = 'cpu'
-
- input_size = model.default_cfg['input_size']
- example_inputs = torch.randn(1, *input_size).to(device)
- test_output = model(example_inputs)
-
- print(model)
- prunable = True
- try:
- base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
- pruner = tp.pruner.MagnitudePruner(
- model,
- example_inputs,
- global_pruning=False, # If False, a uniform sparsity will be assigned to different layers.
- importance=imp, # importance criterion for parameter selection
- iterative_steps=1, # the number of iterations to achieve target sparsity
- ch_sparsity=0.5,
- ignored_layers=[],
- )
- pruner.step()
- test_output = model(example_inputs)
- pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
- print("Base MACs: %d, Pruned MACs: %d"%(base_macs, pruned_macs))
- print("Base Params: %d, Pruned Params: %d"%(base_params, pruned_params))
- except Exception as e:
- prunable = False
-
- if prunable:
- prunable_list.append(model_name)
- else:
- unprunable_list.append(model_name)
-
- print("Prunable: %d models, \n %s\n"%(len(prunable_list), prunable_list))
- print("Unprunable: %d models, \n %s\n"%(len(unprunable_list), unprunable_list))
\ No newline at end of file
From b6b8fdb8405cb02ce78b939c963f6a477c8f644c Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Mon, 2 Oct 2023 17:14:56 +0800
Subject: [PATCH 07/17] Update readme
---
examples/timm_models/readme.md | 3 +++
1 file changed, 3 insertions(+)
diff --git a/examples/timm_models/readme.md b/examples/timm_models/readme.md
index 2f146ab5..b987a310 100644
--- a/examples/timm_models/readme.md
+++ b/examples/timm_models/readme.md
@@ -13,6 +13,9 @@ Output:
```
## 1. Pruning
+
+Some models might requires additional modifications to enable pruning. For example, we need to reimplement the forward function of `vit` to relax the constraint in structure. Refer to [examples/transformers](../transformers/) for more details.
+
```bash
python prune_timm_models.py --model convnext_xxlarge --ch_sparsity 0.5 # --global_pruning
```
From d62a3f25983b5f6ca2a989a2eed3d637a823c72a Mon Sep 17 00:00:00 2001
From: Gongfan Fang
Date: Mon, 2 Oct 2023 09:17:20 +0000
Subject: [PATCH 08/17] Update readme.md
---
examples/timm_models/readme.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/timm_models/readme.md b/examples/timm_models/readme.md
index b987a310..41e22708 100644
--- a/examples/timm_models/readme.md
+++ b/examples/timm_models/readme.md
@@ -14,7 +14,7 @@ Output:
## 1. Pruning
-Some models might requires additional modifications to enable pruning. For example, we need to reimplement the forward function of `vit` to relax the constraint in structure. Refer to [examples/transformers](../transformers/) for more details.
+Some models might require additional modifications to enable pruning. For example, we need to reimplement the forward function of `vit` to relax the constraint in structure. Refer to [examples/transformers](../transformers/) for more details.
```bash
python prune_timm_models.py --model convnext_xxlarge --ch_sparsity 0.5 # --global_pruning
From 1d12bcab36b13eeea4ac0d8f7861892236279202 Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Mon, 2 Oct 2023 17:17:54 +0800
Subject: [PATCH 09/17] Update comments
---
examples/transformers/prune_timm_vit.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/examples/transformers/prune_timm_vit.py b/examples/transformers/prune_timm_vit.py
index f1d8f654..3c752c05 100644
--- a/examples/transformers/prune_timm_vit.py
+++ b/examples/transformers/prune_timm_vit.py
@@ -33,6 +33,7 @@ def parse_args():
# Here we re-implement the forward function of timm.models.vision_transformer.Attention
# as the original forward function requires the input and output channels to be identical.
def forward(self, x):
+ """https://github.com/huggingface/pytorch-image-models/blob/054c763fcaa7d241564439ae05fbe919ed85e614/timm/models/vision_transformer.py#L79"""
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
@@ -50,7 +51,7 @@ def forward(self, x):
attn = self.attn_drop(attn)
x = attn @ v
- x = x.transpose(1, 2).reshape(B, N, -1)
+ x = x.transpose(1, 2).reshape(B, N, -1) # original implementation: x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
From 9647fa16bc743be403a0fbaed0d8de3299571ef0 Mon Sep 17 00:00:00 2001
From: Gongfan Fang
Date: Mon, 2 Oct 2023 09:17:59 +0000
Subject: [PATCH 10/17] Update readme.md
---
examples/timm_models/readme.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/timm_models/readme.md b/examples/timm_models/readme.md
index 41e22708..41a9b3fb 100644
--- a/examples/timm_models/readme.md
+++ b/examples/timm_models/readme.md
@@ -14,7 +14,7 @@ Output:
## 1. Pruning
-Some models might require additional modifications to enable pruning. For example, we need to reimplement the forward function of `vit` to relax the constraint in structure. Refer to [examples/transformers](../transformers/) for more details.
+Some models might require additional modifications to enable pruning. For example, we need to reimplement the forward function of `vit` to relax the constraint in structure. Refer to [examples/transformers/prune_timm_vit.py](../transformers/prune_timm_vit.py) for more details.
```bash
python prune_timm_models.py --model convnext_xxlarge --ch_sparsity 0.5 # --global_pruning
From 82f04be168b310615bd6304d4dcafd0f96e99ea4 Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Mon, 2 Oct 2023 19:52:39 +0800
Subject: [PATCH 11/17] Add a new argument for Transformer pruning
---
.gitignore | 3 +-
examples/timm_models/prune_timm_models.py | 1 -
examples/transformers/prune_hf_bert.py | 10 ++---
examples/transformers/prune_hf_swin.py | 10 ++---
examples/transformers/prune_hf_vit.py | 10 ++---
examples/transformers/prune_timm_vit.py | 6 +--
.../scripts/prune_timm_vit_b_16_l1_uniform.sh | 3 +-
torch_pruning/pruner/algorithms/metapruner.py | 40 ++++++++++---------
8 files changed, 43 insertions(+), 40 deletions(-)
diff --git a/.gitignore b/.gitignore
index 250657f5..968fb052 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,4 +27,5 @@ torch_pruning_bak
benchmarks/run
output
acc.png
-pretrained
\ No newline at end of file
+pretrained
+data
\ No newline at end of file
diff --git a/examples/timm_models/prune_timm_models.py b/examples/timm_models/prune_timm_models.py
index 6c706edf..6e3c50ea 100644
--- a/examples/timm_models/prune_timm_models.py
+++ b/examples/timm_models/prune_timm_models.py
@@ -27,7 +27,6 @@ def main():
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = timm.create_model(args.model, pretrained=args.pretrained, no_jit=True).eval().to(device)
-
imp = tp.importance.GroupNormImportance()
print("Pruning %s..."%args.model)
diff --git a/examples/transformers/prune_hf_bert.py b/examples/transformers/prune_hf_bert.py
index 69546850..635fd5cb 100644
--- a/examples/transformers/prune_hf_bert.py
+++ b/examples/transformers/prune_hf_bert.py
@@ -14,14 +14,14 @@
imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
-channel_groups = {}
+num_heads = {}
# All heads should be pruned simultaneously, so we group channels by head.
for m in model.modules():
if isinstance(m, BertSelfAttention):
- channel_groups[m.query] = m.num_attention_heads
- channel_groups[m.key] = m.num_attention_heads
- channel_groups[m.value] = m.num_attention_heads
+ num_heads[m.query] = m.num_attention_heads
+ num_heads[m.key] = m.num_attention_heads
+ num_heads[m.value] = m.num_attention_heads
pruner = tp.pruner.MetaPruner(
model,
@@ -30,7 +30,7 @@
importance=imp, # importance criterion for parameter selection
iterative_steps=1, # the number of iterations to achieve target sparsity
ch_sparsity=0.5,
- channel_groups=channel_groups,
+ num_heads=num_heads,
output_transform=lambda out: out.pooler_output.sum(),
ignored_layers=[model.pooler],
)
diff --git a/examples/transformers/prune_hf_swin.py b/examples/transformers/prune_hf_swin.py
index 07b68d2d..45c4eb2d 100644
--- a/examples/transformers/prune_hf_swin.py
+++ b/examples/transformers/prune_hf_swin.py
@@ -48,15 +48,15 @@ def get_in_channels(self, layer):
print(model)
imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
-channel_groups = {}
+num_heads = {}
ignored_layers = [model.classifier]
# All heads should be pruned simultaneously, so we group channels by head.
for m in model.modules():
if isinstance(m, SwinSelfAttention):
- channel_groups[m.query] = m.num_attention_heads
- channel_groups[m.key] = m.num_attention_heads
- channel_groups[m.value] = m.num_attention_heads
+ num_heads[m.query] = m.num_attention_heads
+ num_heads[m.key] = m.num_attention_heads
+ num_heads[m.value] = m.num_attention_heads
pruner = tp.pruner.MetaPruner(
model,
@@ -65,7 +65,7 @@ def get_in_channels(self, layer):
importance=imp, # importance criterion for parameter selection
iterative_steps=1, # the number of iterations to achieve target sparsity
ch_sparsity=0.5,
- channel_groups=channel_groups,
+ num_heads=num_heads,
output_transform=lambda out: out.logits.sum(),
ignored_layers=ignored_layers,
customized_pruners={SwinPatchMerging: SwinPatchMergingPruner()},
diff --git a/examples/transformers/prune_hf_vit.py b/examples/transformers/prune_hf_vit.py
index bc79b481..7636cf8d 100644
--- a/examples/transformers/prune_hf_vit.py
+++ b/examples/transformers/prune_hf_vit.py
@@ -96,14 +96,14 @@ def validate_model(model, val_loader, device):
print("Accuracy: %.4f, Loss: %.4f"%(acc_ori, loss_ori))
print("Pruning %s..."%args.model_name)
-channel_groups = {}
+num_heads = {}
ignored_layers = [model.classifier]
# All heads should be pruned simultaneously, so we group channels by head.
for m in model.modules():
if isinstance(m, ViTSelfAttention):
- channel_groups[m.query] = m.num_attention_heads
- channel_groups[m.key] = m.num_attention_heads
- channel_groups[m.value] = m.num_attention_heads
+ num_heads[m.query] = m.num_attention_heads
+ num_heads[m.key] = m.num_attention_heads
+ num_heads[m.value] = m.num_attention_heads
if args.bottleneck and isinstance(m, ViTSelfOutput):
ignored_layers.append(m.dense)
@@ -115,7 +115,7 @@ def validate_model(model, val_loader, device):
ch_sparsity=args.pruning_ratio, # target sparsity
ignored_layers=ignored_layers,
output_transform=lambda out: out.logits.sum(),
- channel_groups=channel_groups,
+ num_heads=num_heads,
)
if isinstance(imp, tp.importance.TaylorImportance):
diff --git a/examples/transformers/prune_timm_vit.py b/examples/transformers/prune_timm_vit.py
index 3c752c05..af957867 100644
--- a/examples/transformers/prune_timm_vit.py
+++ b/examples/transformers/prune_timm_vit.py
@@ -123,12 +123,12 @@ def main():
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
print("Pruning %s..."%args.model_name)
- ch_groups = {}
+ num_heads = {}
ignored_layers = [model.head]
for m in model.modules():
if isinstance(m, timm.models.vision_transformer.Attention):
m.forward = forward.__get__(m, timm.models.vision_transformer.Attention) # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
- ch_groups[m.qkv] = m.num_heads * 3
+ num_heads[m.qkv] = m.num_heads
if args.bottleneck and isinstance(m, timm.models.vision_transformer.Mlp):
ignored_layers.append(m.fc2) # only prune the internal layers of FFN & Attention
@@ -144,7 +144,7 @@ def main():
importance=imp, # importance criterion for parameter selection
ch_sparsity=args.pruning_ratio, # target sparsity
ignored_layers=ignored_layers,
- channel_groups=ch_groups,
+ num_heads=num_heads, # number of heads in self attention
round_to=16,
)
if isinstance(imp, (tp.importance.GroupTaylorImportance, tp.importance.GroupHessianImportance)):
diff --git a/examples/transformers/scripts/prune_timm_vit_b_16_l1_uniform.sh b/examples/transformers/scripts/prune_timm_vit_b_16_l1_uniform.sh
index 7cdb17bc..f499f27e 100644
--- a/examples/transformers/scripts/prune_timm_vit_b_16_l1_uniform.sh
+++ b/examples/transformers/scripts/prune_timm_vit_b_16_l1_uniform.sh
@@ -3,8 +3,7 @@ python prune_timm_vit.py \
--pruning_type l1 \
--pruning_ratio 0.5 \
--taylor_batchs 10 \
- --data_path ~/Datasets/shared/imagenet \
- --test_accuracy \
+ --data_path data/imagenet \
--train_batch_size 64 \
--val_batch_size 64 \
--save_as output/pruned/vit_base_patch16_224_pruned_l1_uniform.pth \
\ No newline at end of file
diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py
index 692f135d..4e12bb89 100644
--- a/torch_pruning/pruner/algorithms/metapruner.py
+++ b/torch_pruning/pruner/algorithms/metapruner.py
@@ -52,6 +52,7 @@ def __init__(
# Advanced
in_channel_groups: typing.Dict[nn.Module, int] = dict(), # The number of channel groups for layer input
out_channel_groups: typing.Dict[nn.Module, int] = dict(), # The number of channel groups for layer output
+ num_heads: typing.Dict[nn.Module, int] = dict(), # The number of heads for multi-head attention
customized_pruners: typing.Dict[typing.Any, function.BasePruningFunc] = None, # pruners for customized layers. E.g., {nn.Linear: my_linear_pruner}
unwrapped_parameters: typing.Dict[nn.Parameter, int] = None, # unwrapped nn.Parameters & pruning_dims. For example, {ViT.pos_emb: 0}
root_module_types: typing.List = [ops.TORCH_CONV, ops.TORCH_LINEAR, ops.TORCH_LSTM], # root module for each group
@@ -59,7 +60,7 @@ def __init__(
output_transform: typing.Callable = None, # a function to transform network outputs
# deprecated
- channel_groups: typing.Dict[nn.Module, int] = dict(), # channel groups for layers
+ channel_groups: typing.Dict[nn.Module, int] = dict(), # channel grouping
):
self.model = model
self.importance = importance
@@ -72,9 +73,12 @@ def __init__(
warnings.warn("channel_groups is deprecated. Please use in_channel_groups and out_channel_groups instead.")
out_channel_groups.update(channel_groups)
+ if len(num_heads) > 0:
+ out_channel_groups.update(num_heads)
+
self.in_channel_groups = in_channel_groups
self.out_channel_groups = out_channel_groups
-
+ self.num_heads = num_heads
self.root_module_types = root_module_types
self.round_to = round_to
@@ -129,7 +133,7 @@ def __init__(
in_ch_group = layer_pruner.get_in_channel_groups(m)
out_ch_group = layer_pruner.get_out_channel_groups(m)
if isinstance(m, ops.TORCH_CONV) and m.groups == m.out_channels:
- continue
+ continue
if in_ch_group > 1:
self.in_channel_groups[m] = in_ch_group
if out_ch_group > 1:
@@ -149,7 +153,7 @@ def __init__(
if self.global_pruning:
initial_total_channels = 0
for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):
- group = self._downstream_node_as_root_if_unbind(group)
+ group = self._downstream_node_as_root_if_attention(group)
initial_total_channels += ( (self.DG.get_out_channels(group[0][0].target.module) ) // self._get_channel_groups(group) )
self.initial_total_channels = initial_total_channels
@@ -209,8 +213,8 @@ def _check_sparsity(self, group) -> bool:
def _get_channel_groups(self, group) -> int:
ch_groups = 1
- has_unbind = False
- unbind_node = None
+ #has_unbind = False
+ #unbind_node = None
for dep, _ in group:
module = dep.target.module
@@ -220,24 +224,24 @@ def _get_channel_groups(self, group) -> int:
if module in channel_groups:
ch_groups = channel_groups[module]
- if dep.source.type==ops.OPTYPE.UNBIND:
- has_unbind = True
- unbind_node = dep.source
+ #if dep.source.type==ops.OPTYPE.UNBIND:
+ # has_unbind = True
+ # unbind_node = dep.source
- if has_unbind and ch_groups>1:
- ch_groups = ch_groups // len(unbind_node.outputs)
+ #if has_unbind and ch_groups>1:
+ # ch_groups = ch_groups // len(unbind_node.outputs)
return ch_groups # no channel grouping
- def _downstream_node_as_root_if_unbind(self, group):
+ def _downstream_node_as_root_if_attention(self, group):
# Use a downstream node as the root if torch.unbind exists. TODO: find a general way to handle torch.unbind in timm
- qkv_unbind = False
+ is_attention = False
downstream_dep = None
for _dep, _idxs in group:
- if _dep.source.type == ops.OPTYPE.UNBIND:
- qkv_unbind = True
+ if _dep.source.module in self.num_heads:
+ is_attention = True
if isinstance(_dep.target.module, tuple(self.root_module_types)):
downstream_dep = _dep
- if qkv_unbind and downstream_dep is not None: # use a downstream node as the root node
+ if is_attention and downstream_dep is not None: # use a downstream node as the root node for attention layers
group = self.DG.get_pruning_group(downstream_dep.target.module, downstream_dep.handler, _idxs)
return group
@@ -254,7 +258,7 @@ def prune_local(self) -> typing.Generator:
for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):
if self._check_sparsity(group): # check pruning ratio
- group = self._downstream_node_as_root_if_unbind(group)
+ group = self._downstream_node_as_root_if_attention(group)
module = group[0][0].target.module
pruning_fn = group[0][0].handler
@@ -315,7 +319,7 @@ def prune_global(self) -> typing.Generator:
global_importance = []
for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):
if self._check_sparsity(group):
- group = self._downstream_node_as_root_if_unbind(group)
+ group = self._downstream_node_as_root_if_attention(group)
ch_groups = self._get_channel_groups(group)
imp = self.estimate_importance(group, ch_groups=ch_groups)
if imp is None: continue
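For reference, a minimal sketch of the new ``num_heads`` argument (mirroring ``examples/transformers/prune_timm_vit.py``; the Attention forward re-implementation from that script is assumed and omitted here):

```python
# Hedged sketch of MetaPruner's new num_heads argument for a timm ViT.
# Assumes the Attention.forward monkey-patch from examples/transformers/prune_timm_vit.py
# has been applied so that the qkv output channels may differ from the input channels.
import timm
import torch
import torch_pruning as tp

model = timm.create_model('vit_base_patch16_224', pretrained=False).eval()
example_inputs = torch.randn(1, 3, 224, 224)

num_heads = {}
for m in model.modules():
    if isinstance(m, timm.models.vision_transformer.Attention):
        num_heads[m.qkv] = m.num_heads  # previously: channel_groups[m.qkv] = m.num_heads * 3

pruner = tp.pruner.MetaPruner(
    model, example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),
    ch_sparsity=0.5,
    ignored_layers=[model.head],
    num_heads=num_heads,  # q/k/v channels are grouped by attention head
)
pruner.step()
```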
From 6a980834f06dd8d86cb4a8d08e6af6d28983b85c Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Mon, 2 Oct 2023 19:57:38 +0800
Subject: [PATCH 12/17] Default ImageNet root: data/imagenet
---
benchmarks/benchmark_importance_criteria.py | 2 +-
examples/transformers/prune_hf_vit.py | 2 +-
examples/transformers/prune_timm_vit.py | 2 +-
examples/transformers/readme.md | 2 +-
.../transformers/scripts/finetune_hf_vit_b_16_l1_uniform.sh | 2 +-
.../transformers/scripts/finetune_hf_vit_b_16_taylor_uniform.sh | 2 +-
.../scripts/finetune_timm_deit_b_16_taylor_uniform.sh | 2 +-
.../scripts/finetune_timm_vit_b_16_hessian_uniform.sh | 2 +-
.../transformers/scripts/finetune_timm_vit_b_16_l1_uniform.sh | 2 +-
.../transformers/scripts/finetune_timm_vit_b_16_l2_uniform.sh | 2 +-
.../scripts/finetune_timm_vit_b_16_taylor_bottleneck.sh | 2 +-
.../scripts/finetune_timm_vit_b_16_taylor_uniform.sh | 2 +-
examples/transformers/scripts/prune_hf_vit_b_16_l1_uniform.sh | 2 +-
.../transformers/scripts/prune_hf_vit_b_16_taylor_uniform.sh | 2 +-
.../transformers/scripts/prune_timm_deit_b_16_taylor_uniform.sh | 2 +-
.../transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh | 2 +-
examples/transformers/scripts/prune_timm_vit_b_16_l1_uniform.sh | 1 +
examples/transformers/scripts/prune_timm_vit_b_16_l2_uniform.sh | 2 +-
.../scripts/prune_timm_vit_b_16_taylor_bottleneck.sh | 2 +-
.../transformers/scripts/prune_timm_vit_b_16_taylor_uniform.sh | 2 +-
.../scripts/prune_timm_vit_b_16_taylor_uniform_global.sh | 2 +-
examples/transformers/scripts/test_pretrained_hf_vit_b_16.sh | 2 +-
examples/transformers/scripts/test_pretrained_timm_deit_b_16.sh | 2 +-
examples/transformers/scripts/test_pretrained_timm_vit_b_16.sh | 2 +-
24 files changed, 24 insertions(+), 23 deletions(-)
diff --git a/benchmarks/benchmark_importance_criteria.py b/benchmarks/benchmark_importance_criteria.py
index 3cf30161..c2ebc6e8 100644
--- a/benchmarks/benchmark_importance_criteria.py
+++ b/benchmarks/benchmark_importance_criteria.py
@@ -8,7 +8,7 @@
import matplotlib.pyplot as plt
N_batchs = 10
-imagenet_root = '~/Datasets/shared/imagenet/'
+imagenet_root = 'data/imagenet'
print('Parsing dataset...')
train_dst = ImageFolder(os.path.join(imagenet_root, 'train'), transform=T.Compose(
[
diff --git a/examples/transformers/prune_hf_vit.py b/examples/transformers/prune_hf_vit.py
index 7636cf8d..27256239 100644
--- a/examples/transformers/prune_hf_vit.py
+++ b/examples/transformers/prune_hf_vit.py
@@ -14,7 +14,7 @@
parser = argparse.ArgumentParser(description='ViT Pruning')
parser.add_argument('--model_name', default='google/vit-base-patch16-224', type=str, help='model name')
-parser.add_argument('--data_path', default='~/Datasets/shared/imagenet/', type=str, help='model name')
+parser.add_argument('--data_path', default='data/imagenet', type=str, help='path to ImageNet data')
parser.add_argument('--taylor_batchs', default=10, type=int, help='number of batchs for taylor criterion')
parser.add_argument('--pruning_ratio', default=0.5, type=float, help='prune ratio')
parser.add_argument('--bottleneck', default=False, action='store_true', help='bottleneck or uniform')
diff --git a/examples/transformers/prune_timm_vit.py b/examples/transformers/prune_timm_vit.py
index af957867..9b46e838 100644
--- a/examples/transformers/prune_timm_vit.py
+++ b/examples/transformers/prune_timm_vit.py
@@ -16,7 +16,7 @@
def parse_args():
parser = argparse.ArgumentParser(description='Timm ViT Pruning')
parser.add_argument('--model_name', default='vit_base_patch16_224', type=str, help='model name')
- parser.add_argument('--data_path', default='~/Datasets/shared/imagenet/', type=str, help='model name')
+    parser.add_argument('--data_path', default='data/imagenet', type=str, help='path to ImageNet data')
parser.add_argument('--taylor_batchs', default=10, type=int, help='number of batchs for taylor criterion')
parser.add_argument('--pruning_ratio', default=0.5, type=float, help='prune ratio')
parser.add_argument('--bottleneck', default=False, action='store_true', help='bottleneck or uniform')
diff --git a/examples/transformers/readme.md b/examples/transformers/readme.md
index 0be8c156..49053764 100644
--- a/examples/transformers/readme.md
+++ b/examples/transformers/readme.md
@@ -5,7 +5,7 @@
### Data
Please prepare the ImageNet-1K dataset as follows and modify the data root in the script.
```
-imagenet/
+./data/imagenet/
train/
n01440764/
n01440764_10026.JPEG
diff --git a/examples/transformers/scripts/finetune_hf_vit_b_16_l1_uniform.sh b/examples/transformers/scripts/finetune_hf_vit_b_16_l1_uniform.sh
index cf746e75..e0ca0127 100644
--- a/examples/transformers/scripts/finetune_hf_vit_b_16_l1_uniform.sh
+++ b/examples/transformers/scripts/finetune_hf_vit_b_16_l1_uniform.sh
@@ -16,6 +16,6 @@ torchrun --nproc_per_node=8 finetune.py \
--clip-grad-norm 1 \
--ra-sampler \
--cutmix-alpha 1.0 \
- --data-path "~/Datasets/shared/imagenet/" \
+ --data-path "data/imagenet" \
--output-dir output/hf_vit_b_16_pruning_l1_uniform \
--is_huggingface \
\ No newline at end of file
diff --git a/examples/transformers/scripts/finetune_hf_vit_b_16_taylor_uniform.sh b/examples/transformers/scripts/finetune_hf_vit_b_16_taylor_uniform.sh
index bc192e22..edb2440b 100644
--- a/examples/transformers/scripts/finetune_hf_vit_b_16_taylor_uniform.sh
+++ b/examples/transformers/scripts/finetune_hf_vit_b_16_taylor_uniform.sh
@@ -16,6 +16,6 @@ torchrun --nproc_per_node=8 finetune.py \
--clip-grad-norm 1 \
--ra-sampler \
--cutmix-alpha 1.0 \
- --data-path "~/Datasets/shared/imagenet/" \
+ --data-path "data/imagenet" \
--output-dir output/hf_vit_b_16_pruning_taylor_uniform \
--is_huggingface \
\ No newline at end of file
diff --git a/examples/transformers/scripts/finetune_timm_deit_b_16_taylor_uniform.sh b/examples/transformers/scripts/finetune_timm_deit_b_16_taylor_uniform.sh
index bf5dd16d..4358f6f7 100644
--- a/examples/transformers/scripts/finetune_timm_deit_b_16_taylor_uniform.sh
+++ b/examples/transformers/scripts/finetune_timm_deit_b_16_taylor_uniform.sh
@@ -17,6 +17,6 @@ torchrun --nproc_per_node=8 finetune.py \
--ra-sampler \
--random-erase 0.25 \
--cutmix-alpha 1.0 \
- --data-path "~/Datasets/shared/imagenet/" \
+ --data-path "data/imagenet" \
--output-dir output/deit_b_16_pruning_taylor_uniform \
--use_imagenet_mean_std \
\ No newline at end of file
diff --git a/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh b/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh
index 779f14bb..50012b2c 100644
--- a/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh
+++ b/examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh
@@ -16,5 +16,5 @@ torchrun --nproc_per_node=8 finetune.py \
--clip-grad-norm 1 \
--ra-sampler \
--cutmix-alpha 1.0 \
- --data-path "~/Datasets/shared/imagenet/" \
+ --data-path "data/imagenet" \
--output-dir output/vit_b_16_pruning_hessian_uniform
\ No newline at end of file
diff --git a/examples/transformers/scripts/finetune_timm_vit_b_16_l1_uniform.sh b/examples/transformers/scripts/finetune_timm_vit_b_16_l1_uniform.sh
index f66a0e87..95f5e671 100644
--- a/examples/transformers/scripts/finetune_timm_vit_b_16_l1_uniform.sh
+++ b/examples/transformers/scripts/finetune_timm_vit_b_16_l1_uniform.sh
@@ -16,5 +16,5 @@ torchrun --nproc_per_node=8 finetune.py \
--clip-grad-norm 1 \
--ra-sampler \
--cutmix-alpha 1.0 \
- --data-path "~/Datasets/shared/imagenet/" \
+ --data-path "data/imagenet" \
--output-dir output/vit_b_16_pruning_l1_uniform
\ No newline at end of file
diff --git a/examples/transformers/scripts/finetune_timm_vit_b_16_l2_uniform.sh b/examples/transformers/scripts/finetune_timm_vit_b_16_l2_uniform.sh
index 1e764fa2..e233fe7e 100644
--- a/examples/transformers/scripts/finetune_timm_vit_b_16_l2_uniform.sh
+++ b/examples/transformers/scripts/finetune_timm_vit_b_16_l2_uniform.sh
@@ -16,5 +16,5 @@ torchrun --nproc_per_node=8 finetune.py \
--clip-grad-norm 1 \
--ra-sampler \
--cutmix-alpha 1.0 \
- --data-path "~/Datasets/shared/imagenet/" \
+ --data-path "data/imagenet" \
--output-dir output/vit_b_16_pruning_l2_uniform
\ No newline at end of file
diff --git a/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_bottleneck.sh b/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_bottleneck.sh
index 0cfb7555..737bdd13 100644
--- a/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_bottleneck.sh
+++ b/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_bottleneck.sh
@@ -16,5 +16,5 @@ torchrun --nproc_per_node=8 finetune.py \
--clip-grad-norm 1 \
--ra-sampler \
--cutmix-alpha 1.0 \
- --data-path "~/Datasets/shared/imagenet/" \
+ --data-path "data/imagenet" \
--output-dir output/vit_b_16_pruning_taylor_bottleneck
\ No newline at end of file
diff --git a/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_uniform.sh b/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_uniform.sh
index 888085fd..a0fa0675 100644
--- a/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_uniform.sh
+++ b/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_uniform.sh
@@ -17,5 +17,5 @@ torchrun --nproc_per_node=8 finetune.py \
--ra-sampler \
--random-erase 0.25 \
--cutmix-alpha 1.0 \
- --data-path "~/Datasets/shared/imagenet/" \
+ --data-path "data/imagenet" \
--output-dir output/vit_b_16_pruning_taylor_uniform_v2
\ No newline at end of file
diff --git a/examples/transformers/scripts/prune_hf_vit_b_16_l1_uniform.sh b/examples/transformers/scripts/prune_hf_vit_b_16_l1_uniform.sh
index 8a3e3bd7..4656d0e8 100644
--- a/examples/transformers/scripts/prune_hf_vit_b_16_l1_uniform.sh
+++ b/examples/transformers/scripts/prune_hf_vit_b_16_l1_uniform.sh
@@ -3,7 +3,7 @@ python prune_hf_vit.py \
--pruning_type l1 \
--pruning_ratio 0.5 \
--taylor_batchs 10 \
- --data_path ~/Datasets/shared/imagenet \
+ --data_path data/imagenet \
--test_accuracy \
--train_batch_size 64 \
--val_batch_size 64 \
diff --git a/examples/transformers/scripts/prune_hf_vit_b_16_taylor_uniform.sh b/examples/transformers/scripts/prune_hf_vit_b_16_taylor_uniform.sh
index 3d606509..6c2ff8e5 100644
--- a/examples/transformers/scripts/prune_hf_vit_b_16_taylor_uniform.sh
+++ b/examples/transformers/scripts/prune_hf_vit_b_16_taylor_uniform.sh
@@ -3,7 +3,7 @@ python prune_hf_vit.py \
--pruning_type taylor \
--pruning_ratio 0.5 \
--taylor_batchs 10 \
- --data_path ~/Datasets/shared/imagenet \
+ --data_path data/imagenet \
--test_accuracy \
--train_batch_size 64 \
--val_batch_size 64 \
diff --git a/examples/transformers/scripts/prune_timm_deit_b_16_taylor_uniform.sh b/examples/transformers/scripts/prune_timm_deit_b_16_taylor_uniform.sh
index d58e94d6..0040e5ad 100644
--- a/examples/transformers/scripts/prune_timm_deit_b_16_taylor_uniform.sh
+++ b/examples/transformers/scripts/prune_timm_deit_b_16_taylor_uniform.sh
@@ -3,7 +3,7 @@ python prune_timm_vit.py \
--pruning_type taylor \
--pruning_ratio 0.54 \
--taylor_batchs 50 \
- --data_path ~/Datasets/shared/imagenet \
+ --data_path data/imagenet \
--train_batch_size 64 \
--val_batch_size 64 \
--save_as output/pruned/deit_base_patch16_224_pruned_taylor_uniform.pth \
diff --git a/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh b/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh
index 02d9b29f..39e35671 100644
--- a/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh
+++ b/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh
@@ -4,7 +4,7 @@ python prune_timm_vit.py \
--pruning_ratio 0.5 \
--taylor_batchs 10 \
--test_accuracy \
- --data_path ~/Datasets/shared/imagenet \
+ --data_path data/imagenet \
--train_batch_size 64 \
--val_batch_size 64 \
--save_as output/pruned/vit_base_patch16_224_pruned_hessian_uniform.pth \
\ No newline at end of file
diff --git a/examples/transformers/scripts/prune_timm_vit_b_16_l1_uniform.sh b/examples/transformers/scripts/prune_timm_vit_b_16_l1_uniform.sh
index f499f27e..4df87fd5 100644
--- a/examples/transformers/scripts/prune_timm_vit_b_16_l1_uniform.sh
+++ b/examples/transformers/scripts/prune_timm_vit_b_16_l1_uniform.sh
@@ -4,6 +4,7 @@ python prune_timm_vit.py \
--pruning_ratio 0.5 \
--taylor_batchs 10 \
--data_path data/imagenet \
+ --test_accuracy \
--train_batch_size 64 \
--val_batch_size 64 \
--save_as output/pruned/vit_base_patch16_224_pruned_l1_uniform.pth \
\ No newline at end of file
diff --git a/examples/transformers/scripts/prune_timm_vit_b_16_l2_uniform.sh b/examples/transformers/scripts/prune_timm_vit_b_16_l2_uniform.sh
index cdd51e41..21386e8b 100644
--- a/examples/transformers/scripts/prune_timm_vit_b_16_l2_uniform.sh
+++ b/examples/transformers/scripts/prune_timm_vit_b_16_l2_uniform.sh
@@ -3,7 +3,7 @@ python prune_timm_vit.py \
--pruning_type l2 \
--pruning_ratio 0.5 \
--taylor_batchs 10 \
- --data_path ~/Datasets/shared/imagenet \
+ --data_path data/imagenet \
--test_accuracy \
--train_batch_size 64 \
--val_batch_size 64 \
diff --git a/examples/transformers/scripts/prune_timm_vit_b_16_taylor_bottleneck.sh b/examples/transformers/scripts/prune_timm_vit_b_16_taylor_bottleneck.sh
index 149238c2..4ae381d1 100644
--- a/examples/transformers/scripts/prune_timm_vit_b_16_taylor_bottleneck.sh
+++ b/examples/transformers/scripts/prune_timm_vit_b_16_taylor_bottleneck.sh
@@ -3,7 +3,7 @@ python prune_timm_vit.py \
--pruning_type taylor \
--pruning_ratio 0.73 \
--taylor_batchs 10 \
- --data_path ~/Datasets/shared/imagenet \
+ --data_path data/imagenet \
--bottleneck \
--train_batch_size 64 \
--val_batch_size 64 \
diff --git a/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform.sh b/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform.sh
index 2763a364..19589ce0 100644
--- a/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform.sh
+++ b/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform.sh
@@ -3,7 +3,7 @@ python prune_timm_vit.py \
--pruning_type taylor \
--pruning_ratio 0.54 \
--taylor_batchs 50 \
- --data_path ~/Datasets/shared/imagenet \
+ --data_path data/imagenet \
--train_batch_size 64 \
--val_batch_size 64 \
--save_as output/pruned/vit_base_patch16_224_pruned_taylor_uniform.pth \
\ No newline at end of file
diff --git a/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform_global.sh b/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform_global.sh
index 9173b20b..24672b68 100644
--- a/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform_global.sh
+++ b/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform_global.sh
@@ -3,7 +3,7 @@ python prune_timm_vit.py \
--pruning_type taylor \
--pruning_ratio 0.6 \
--taylor_batchs 10 \
- --data_path ~/Datasets/shared/imagenet \
+ --data_path data/imagenet \
--train_batch_size 64 \
--val_batch_size 64 \
--save_as output/pruned/vit_base_patch16_224_pruned_taylor_uniform.pth \
diff --git a/examples/transformers/scripts/test_pretrained_hf_vit_b_16.sh b/examples/transformers/scripts/test_pretrained_hf_vit_b_16.sh
index 353b0d06..56f8838f 100644
--- a/examples/transformers/scripts/test_pretrained_hf_vit_b_16.sh
+++ b/examples/transformers/scripts/test_pretrained_hf_vit_b_16.sh
@@ -16,7 +16,7 @@ python finetune.py \
--clip-grad-norm 1 \
--ra-sampler \
--cutmix-alpha 1.0 \
- --data-path "~/Datasets/shared/imagenet/" \
+ --data-path "data/imagenet" \
--test-only \
--interpolation bilinear \
--is_huggingface \
\ No newline at end of file
diff --git a/examples/transformers/scripts/test_pretrained_timm_deit_b_16.sh b/examples/transformers/scripts/test_pretrained_timm_deit_b_16.sh
index 9d6e3567..7fe0da2d 100644
--- a/examples/transformers/scripts/test_pretrained_timm_deit_b_16.sh
+++ b/examples/transformers/scripts/test_pretrained_timm_deit_b_16.sh
@@ -16,6 +16,6 @@ python finetune.py \
--clip-grad-norm 1 \
--ra-sampler \
--cutmix-alpha 1.0 \
- --data-path "~/Datasets/shared/imagenet/" \
+ --data-path "data/imagenet" \
--test-only \
--use_imagenet_mean_std \
diff --git a/examples/transformers/scripts/test_pretrained_timm_vit_b_16.sh b/examples/transformers/scripts/test_pretrained_timm_vit_b_16.sh
index 08b021ec..62f4b7dd 100644
--- a/examples/transformers/scripts/test_pretrained_timm_vit_b_16.sh
+++ b/examples/transformers/scripts/test_pretrained_timm_vit_b_16.sh
@@ -16,5 +16,5 @@ python finetune.py \
--clip-grad-norm 1 \
--ra-sampler \
--cutmix-alpha 1.0 \
- --data-path "~/Datasets/shared/imagenet/" \
+ --data-path "data/imagenet" \
--test-only \
From 0e03c2b0ca400a488270b883f4d9eda388b0393a Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Mon, 2 Oct 2023 20:12:33 +0800
Subject: [PATCH 13/17] Update readme
---
examples/torchvision_models/readme.md | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/examples/torchvision_models/readme.md b/examples/torchvision_models/readme.md
index d198cd95..6d0231a7 100644
--- a/examples/torchvision_models/readme.md
+++ b/examples/torchvision_models/readme.md
@@ -19,14 +19,14 @@ python torchvision_pruning.py
#### Outputs:
```
-Successful Pruning: 81 Models
- ['ssdlite320_mobilenet_v3_large', 'ssd300_vgg16', 'fasterrcnn_resnet50_fpn', 'fasterrcnn_resnet50_fpn_v2', 'fasterrcnn_mobilenet_v3_large_320_fpn', 'fasterrcnn_mobilenet_v3_large_fpn', 'fcos_resnet50_fpn', 'keypointrcnn_resnet50_fpn', 'maskrcnn_resnet50_fpn_v2', 'retinanet_resnet50_fpn_v2', 'alexnet', 'vit_b_16', 'vit_b_32', 'vit_l_16', 'vit_l_32', 'vit_h_14', 'convnext_tiny', 'convnext_small', 'convnext_base', 'convnext_large', 'densenet121', 'densenet169', 'densenet201', 'densenet161', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_v2_s', 'efficientnet_v2_m', 'efficientnet_v2_l', 'googlenet', 'inception_v3', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small', 'regnet_y_400mf', 'regnet_y_800mf', 'regnet_y_1_6gf', 'regnet_y_3_2gf', 'regnet_y_8gf', 'regnet_y_16gf', 'regnet_y_32gf', 'regnet_y_128gf', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2', 'fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101', 'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0']
+Successful Pruning: 77 Models
+ ['ssdlite320_mobilenet_v3_large', 'ssd300_vgg16', 'fasterrcnn_resnet50_fpn', 'fasterrcnn_resnet50_fpn_v2', 'fasterrcnn_mobilenet_v3_large_320_fpn', 'fasterrcnn_mobilenet_v3_large_fpn', 'fcos_resnet50_fpn', 'keypointrcnn_resnet50_fpn', 'maskrcnn_resnet50_fpn_v2', 'retinanet_resnet50_fpn_v2', 'alexnet', 'vit_b_16', 'vit_b_32', 'vit_l_16', 'vit_l_32', 'vit_h_14', 'convnext_tiny', 'convnext_small', 'convnext_base', 'convnext_large', 'densenet121', 'densenet169', 'densenet201', 'densenet161', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_v2_s', 'efficientnet_v2_m', 'efficientnet_v2_l', 'googlenet', 'inception_v3', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small', 'regnet_y_400mf', 'regnet_y_800mf', 'regnet_y_1_6gf', 'regnet_y_3_2gf', 'regnet_y_8gf', 'regnet_y_16gf', 'regnet_y_32gf', 'regnet_y_128gf', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2', 'fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101', 'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn']
```
```
-Unsuccessful Pruning: 4 Models
- ['raft_large', 'swin_t', 'swin_s', 'swin_b']
+Unsuccessful Pruning: 8 Models
+ ['raft_large', 'swin_t', 'swin_s', 'swin_b', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0']
```
#### Vision Transformer Example
From 7e76187a346950ea86b6b345ff5c13608beeff6e Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Mon, 2 Oct 2023 20:54:17 +0800
Subject: [PATCH 14/17] Fixed a bug with torch 2.0
---
tests/test_concat_split.py | 6 +++---
torch_pruning/dependency.py | 6 ++++--
torch_pruning/pruner/algorithms/metapruner.py | 1 -
3 files changed, 7 insertions(+), 6 deletions(-)
diff --git a/tests/test_concat_split.py b/tests/test_concat_split.py
index e7811aa7..b59de7c8 100644
--- a/tests/test_concat_split.py
+++ b/tests/test_concat_split.py
@@ -60,13 +60,13 @@ def test_pruner():
ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
ignored_layers=ignored_layers,
)
- print(model)
+
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
for g in pruner.step(interactive=True):
- print(g.details())
+ #print(g.details())
g.prune()
- print(model)
+ print(model)
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(model(example_inputs).shape)
print(
diff --git a/torch_pruning/dependency.py b/torch_pruning/dependency.py
index bc791fa4..2caa2fe3 100644
--- a/torch_pruning/dependency.py
+++ b/torch_pruning/dependency.py
@@ -506,7 +506,6 @@ def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, o
layer_channels = pruner.get_out_channels(m)
group = self.get_pruning_group(
m, pruner.prune_out_channels, list(range(layer_channels)))
-
prunable_group = True
for dep, _ in group:
module = dep.target.module
@@ -852,8 +851,10 @@ def _init_shape_information(self):
if node.type == ops.OPTYPE.SPLIT:
grad_fn = node.grad_fn
- if hasattr(grad_fn, '_saved_self_sizes'):
+
+ if hasattr(grad_fn, '_saved_self_sizes') or hasattr(grad_fn, '_saved_split_sizes'):
MAX_LEGAL_DIM = 100
+
if hasattr(grad_fn, '_saved_split_sizes') and hasattr(grad_fn, '_saved_dim') :
if grad_fn._saved_dim != 1 and grad_fn._saved_dim < MAX_LEGAL_DIM: # a temp fix for pytorch==1.11, where the _saved_dim is an uninitialized value like 118745347895359
continue
@@ -1036,6 +1037,7 @@ def _update_split_index_mapping(self, split_node: Node):
return
offsets = split_node.module.offsets
+
if offsets is None:
return
addressed_dep = []
diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py
index 4e12bb89..e5ba53c6 100644
--- a/torch_pruning/pruner/algorithms/metapruner.py
+++ b/torch_pruning/pruner/algorithms/metapruner.py
@@ -306,7 +306,6 @@ def prune_local(self) -> typing.Generator:
group = self.DG.get_pruning_group(
module, pruning_fn, pruning_idxs.tolist())
-
if self.DG.check_pruning_group(group):
yield group
From 0db771acb6f4bf08b9d867735f3346a91f1e8414 Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Mon, 2 Oct 2023 20:59:36 +0800
Subject: [PATCH 15/17] Update readme
---
README.md | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index 975c756d..8206e73f 100644
--- a/README.md
+++ b/README.md
@@ -244,7 +244,7 @@ With DepGraph, it is easy to design some "group-level" criteria to estimate the
-#### Modify attributes or forward functions
+#### Modify module attributes or forward function
In some implementations, the model forward may rely on static attributes. For example, in [``convformer_s18``](https://github.com/huggingface/pytorch-image-models/blob/054c763fcaa7d241564439ae05fbe919ed85e614/timm/models/metaformer.py#L107) of timm, we have:
@@ -256,13 +256,13 @@ class Scale(nn.Module):
def __init__(self, dim, init_value=1.0, trainable=True, use_nchw=True):
super().__init__()
- self.shape = (dim, 1, 1) if use_nchw else (dim,)
+ self.shape = (dim, 1, 1) if use_nchw else (dim,) # static shape, which should be updated after pruning
self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)
def forward(self, x):
return x * self.scale.view(self.shape) # => x * self.scale.view(-1, 1, 1), this works for pruning
```
-where the ```forward``` function relies on ``self.shape`` during forwarding. But, the true ``self.shape`` changed after pruning, which should be manually adjusted accordingly.
+where the ```forward``` function relies on ``self.shape`` during forwarding. However, the true ``self.shape`` changes after pruning and should be manually updated.
### 3. Save & Load
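
For reference, the snippet below is a minimal sketch of such a manual update, assuming the ``Scale`` module shown above; ``update_scale_shapes`` is a hypothetical helper and not part of the library:

```python
import torch.nn as nn

def update_scale_shapes(model: nn.Module):
    # Hypothetical helper: refresh the static `shape` attribute of every Scale
    # module so it matches the number of channels left after pruning.
    for m in model.modules():
        if type(m).__name__ == "Scale":
            dim = m.scale.shape[0]        # channels remaining after pruning
            use_nchw = len(m.shape) == 3  # (dim, 1, 1) vs (dim,)
            m.shape = (dim, 1, 1) if use_nchw else (dim,)
```

Alternatively, rewriting the forward as ``x * self.scale.view(-1, 1, 1)``, as noted in the comment above, avoids the stale attribute altogether.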
From 32d843d03a0697ff88ea875876c01c14d882d9d6 Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Mon, 2 Oct 2023 21:02:29 +0800
Subject: [PATCH 16/17] Rename
---
examples/transformers/{test_latency.py => measure_latency.py} | 0
examples/transformers/readme.md | 2 +-
2 files changed, 1 insertion(+), 1 deletion(-)
rename examples/transformers/{test_latency.py => measure_latency.py} (100%)
diff --git a/examples/transformers/test_latency.py b/examples/transformers/measure_latency.py
similarity index 100%
rename from examples/transformers/test_latency.py
rename to examples/transformers/measure_latency.py
diff --git a/examples/transformers/readme.md b/examples/transformers/readme.md
index 49053764..a059ffce 100644
--- a/examples/transformers/readme.md
+++ b/examples/transformers/readme.md
@@ -90,7 +90,7 @@ wget https://github.com/VainF/Torch-Pruning/releases/download/v1.2.5/vit_b_16_pr
* Measure the latency of the pruned models
```bash
-python estimate_latency.py --model pretrained/vit_b_16_pruning_taylor_uniform.pth
+python measure_latency.py --model pretrained/vit_b_16_pruning_taylor_uniform.pth
```
## Pruning ViT-ImageNet-1K from [HF Transformers](https://huggingface.co/docs/transformers/index)
From 7653e200c8c6d8e85f72bd9ceeb5567cc8a50838 Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Mon, 2 Oct 2023 21:07:52 +0800
Subject: [PATCH 17/17] Update Comments
---
torch_pruning/pruner/algorithms/metapruner.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py
index e5ba53c6..18abdb5c 100644
--- a/torch_pruning/pruner/algorithms/metapruner.py
+++ b/torch_pruning/pruner/algorithms/metapruner.py
@@ -27,11 +27,17 @@ class MetaPruner:
* round_to (int): round channels to the nearest multiple of round_to. E.g., round_to=8 means channels will be rounded to 8x. Default: None.
# Advanced
+ * in_channel_groups (Dict[nn.Module, int]): The number of channel groups for layer input. Default: dict().
+ * out_channel_groups (Dict[nn.Module, int]): The number of channel groups for layer output. Default: dict().
+ * num_heads (Dict[nn.Module, int]): The number of heads for multi-head attention. Default: dict().
* customized_pruners (dict): a dict containing module-pruner pairs. Default: None.
* unwrapped_parameters (dict): a dict containing unwrapped parameters & pruning dims. Default: None.
* root_module_types (list): types of prunable modules. Default: [nn.Conv2d, nn.Linear, nn.LSTM].
* forward_fn (Callable): A function to execute model.forward. Default: None.
* output_transform (Callable): A function to transform network outputs. Default: None.
+
+ # Deprecated
+ * channel_groups (Dict[nn.Module, int]): output channel grouping. Default: dict().
"""
def __init__(
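
For reference, a minimal usage sketch of the advanced arguments documented above (assuming a timm ViT whose ``Attention`` blocks expose ``qkv`` and ``num_heads``; the argument names follow the docstring, not a verified API surface):

```python
import torch
import torch.nn as nn
import timm
from timm.models.vision_transformer import Attention
import torch_pruning as tp

model = timm.create_model("vit_base_patch16_224", pretrained=False).eval()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.GroupNormImportance()

# Map each qkv layer to its head count so head dimensions stay consistent
# after channel pruning; keep the classifier untouched.
num_heads = {m.qkv: m.num_heads for m in model.modules() if isinstance(m, Attention)}
ignored_layers = [m for m in model.modules()
                  if isinstance(m, nn.Linear) and m.out_features == model.num_classes]

pruner = tp.pruner.MetaPruner(
    model, example_inputs,
    importance=imp,
    ch_sparsity=0.5,
    ignored_layers=ignored_layers,
    num_heads=num_heads,
)
pruner.step()
```

After ``pruner.step()``, attention attributes such as ``num_heads`` or ``head_dim`` may also need manual updates, as described in the "Modify module attributes or forward function" section of the readme.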