Skip to content

Commit

Permalink
Merge pull request #265 from VainF/v1.2
Browse files Browse the repository at this point in the history
V1.2
  • Loading branch information
VainF authored Oct 2, 2023
2 parents 8a9c81f + 7653e20 commit 474bfa6
Show file tree
Hide file tree
Showing 39 changed files with 208 additions and 381 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ torch_pruning_bak
benchmarks/run
output
acc.png
pretrained
pretrained
data
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,25 @@ With DepGraph, it is easy to design some "group-level" criteria to estimate the
<img src="https://github.com/VainF/Torch-Pruning/assets/18592211/11473499-d28a-434b-a8d6-1a53c4b3b7c0" width="45%"></img>
</div>

#### Modify module attributes or forward function

In some implementations, the model's forward pass relies on static attributes. For example, in [``convformer_s18``](https://github.com/huggingface/pytorch-image-models/blob/054c763fcaa7d241564439ae05fbe919ed85e614/timm/models/metaformer.py#L107) of timm, we have:

```python
class Scale(nn.Module):
"""
Scale vector by element multiplications.
"""

def __init__(self, dim, init_value=1.0, trainable=True, use_nchw=True):
super().__init__()
self.shape = (dim, 1, 1) if use_nchw else (dim,) # static shape, which should be updated after pruning
self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)

def forward(self, x):
return x * self.scale.view(self.shape) # => x * self.scale.view(-1, 1, 1), this works for pruning
```
where the ``forward`` function relies on the static attribute ``self.shape``. After pruning, however, the true shape of ``self.scale`` changes, so ``self.shape`` must be updated manually.


### 3. Save & Load
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/benchmark_importance_criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import matplotlib.pyplot as plt

N_batchs = 10
imagenet_root = '~/Datasets/shared/imagenet/'
imagenet_root = 'data/imagenet'
print('Parsing dataset...')
train_dst = ImageFolder(os.path.join(imagenet_root, 'train'), transform=T.Compose(
[
Expand Down
65 changes: 65 additions & 0 deletions examples/timm_models/prune_timm_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Sequence
import timm
from timm.models.vision_transformer import Attention
import torch_pruning as tp
import argparse

# Command-line interface of the script. Parsing happens at import time, so the
# resulting `args` namespace is available to main() as a module-level global.
parser = argparse.ArgumentParser(description='Prune timm models')
parser.add_argument('--model', default=None, type=str, help='model name')
parser.add_argument('--ch_sparsity', default=0.5, type=float, help='channel sparsity')
parser.add_argument('--global_pruning', default=False, action='store_true', help='global pruning')
# The help text used to be a copy-paste of --global_pruning; this flag is forwarded
# to timm.create_model(pretrained=...).
parser.add_argument('--pretrained', default=False, action='store_true', help='load pretrained weights')
parser.add_argument('--list_models', default=False, action='store_true', help='list all models in timm')
args = parser.parse_args()

def main():
    """Prune a timm model selected on the command line and report the savings.

    Reads the module-level ``args`` namespace:

    * ``--list_models`` prints every model name returned by ``timm.list_models()``.
    * ``--model`` selects the model to prune; when it is omitted the function
      returns right after the optional listing.
    * ``--ch_sparsity``, ``--global_pruning`` and ``--pretrained`` are forwarded
      to the pruner / ``timm.create_model``.

    Raises:
        ValueError: if ``--model`` is not one of the names in ``timm.list_models()``.
    """
    timm_models = timm.list_models()
    if args.list_models:
        print(timm_models)
    if args.model is None:
        return
    # Explicit check instead of `assert`: assertions are stripped under `python -O`,
    # which would silently skip the validation of user input.
    if args.model not in timm_models:
        raise ValueError("Model %s is not in timm model list: %s"%(args.model, timm_models))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = timm.create_model(args.model, pretrained=args.pretrained, no_jit=True).eval().to(device)

    imp = tp.importance.GroupNormImportance()
    print("Pruning %s..."%args.model)

    input_size = model.default_cfg['input_size']
    example_inputs = torch.randn(1, *input_size).to(device)
    # Sanity check that the unpruned model runs; the output itself is not needed,
    # so skip building an autograd graph.
    with torch.no_grad():
        model(example_inputs)

    # Keep the classifier untouched: every Linear layer whose output width equals
    # the number of classes is excluded from pruning.
    ignored_layers = []
    for m in model.modules():
        if isinstance(m, nn.Linear) and m.out_features == model.num_classes:
            ignored_layers.append(m)
            print("Ignore classifier layer: ", m)

    print("========Before pruning========")
    print(model)
    base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs,
        global_pruning=args.global_pruning,  # If False, a uniform sparsity will be assigned to different layers.
        importance=imp,  # importance criterion for parameter selection
        iterative_steps=1,  # the number of iterations to achieve target sparsity
        ch_sparsity=args.ch_sparsity,  # target sparsity
        ignored_layers=ignored_layers,
    )
    # In interactive mode step() yields the pruning groups; each one is applied here.
    for g in pruner.step(interactive=True):
        g.prune()

    print("========After pruning========")
    print(model)
    # Sanity check that the pruned model still runs end to end.
    with torch.no_grad():
        model(example_inputs)
    pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
    print("MACs: %.4f G => %.4f G"%(base_macs/1e9, pruned_macs/1e9))
    print("Params: %.4f M => %.4f M"%(base_params/1e6, pruned_params/1e6))


if __name__=='__main__':
    main()
50 changes: 37 additions & 13 deletions examples/timm_models/readme.md
Original file line number Diff line number Diff line change
@@ -1,29 +1,53 @@
# Pruning Models from Timm

## 0. Requirements

## 0. List all models in Timm

```bash
pip install -r requirements.txt
python prune_timm_models.py --list_models
```
Tested environment:

Output:
```
pytorch==1.12.1
timm=0.9.2
['bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_384', 'beit_large_patch16_224', 'beit_large_patch16_384', 'beit_large_patch16_512', 'beitv2_base_patch16_224', ...]
```

## 1. Pruning

```python
python timm_pruning.py
Some models might require additional modifications to enable pruning. For example, we need to reimplement the forward function of `vit` to relax its structural constraints. Refer to [examples/transformers/prune_timm_vit.py](../transformers/prune_timm_vit.py) for more details.

```bash
python prune_timm_models.py --model convnext_xxlarge --ch_sparsity 0.5 # --global_pruning
```

#### Outputs:
Prunable: 119 models,
```
['beit_base_patch16_224', 'beit_base_patch16_384', 'beit_large_patch16_224', 'beit_large_patch16_384', 'beit_large_patch16_512', 'beitv2_base_patch16_224', 'beitv2_large_patch16_224', 'botnet26t_256', 'botnet50ts_256', 'convmixer_768_32', 'convmixer_1024_20_ks9_p14', 'convmixer_1536_20', 'convnext_atto', 'convnext_atto_ols', 'convnext_base', 'convnext_femto', 'convnext_femto_ols', 'convnext_large', 'convnext_large_mlp', 'convnext_nano', 'convnext_nano_ols', 'convnext_pico', 'convnext_pico_ols', 'convnext_small', 'convnext_tiny', 'convnext_tiny_hnf', 'convnext_xlarge', 'convnext_xxlarge', 'convnextv2_atto', 'convnextv2_base', 'convnextv2_femto', 'convnextv2_huge', 'convnextv2_large', 'convnextv2_nano', 'convnextv2_pico', 'convnextv2_small', 'convnextv2_tiny', 'darknet17', 'darknet21', 'darknet53', 'darknetaa53', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'densenet264d', 'dla34', 'dla46_c', 'dla46x_c', 'dla60', 'dla60x', 'dla60x_c', 'dla102', 'dla102x', 'dla102x2', 'dla169', 'eca_botnext26ts_256', 'eca_resnet33ts', 'eca_resnext26ts', 'eca_vovnet39b', 'ecaresnet26t', 'ecaresnet50d', 'ecaresnet50d_pruned', 'ecaresnet50t', 'ecaresnet101d', 'ecaresnet101d_pruned', 'ecaresnet200d', 'ecaresnet269d', 'ecaresnetlight', 'ecaresnext26t_32x4d', 'ecaresnext50t_32x4d', 'efficientnet_b0', 'efficientnet_b0_g8_gn', 'efficientnet_b0_g16_evos', 'efficientnet_b0_gn', 'efficientnet_b1', 'efficientnet_b1_pruned', 'efficientnet_b2', 'efficientnet_b2_pruned', 'efficientnet_b2a', 'efficientnet_b3', 'efficientnet_b3_gn', 'efficientnet_b3_pruned', 'efficientnet_b3a', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8', 'efficientnet_el', 'efficientnet_el_pruned', 'efficientnet_em', 'efficientnet_es', 'efficientnet_es_pruned', 'efficientnet_l2', 'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4', 'efficientnetv2_l', 'efficientnetv2_m', 'efficientnetv2_rw_m', 
'efficientnetv2_rw_s', 'efficientnetv2_rw_t', 'efficientnetv2_s', 'efficientnetv2_xl', 'ese_vovnet19b_dw', 'ese_vovnet19b_slim', 'ese_vovnet19b_slim_dw', 'ese_vovnet39b', 'ese_vovnet57b', 'ese_vovnet99b', 'fbnetc_100', 'fbnetv3_b', 'fbnetv3_d', 'fbnetv3_g', 'gc_efficientnetv2_rw_t', 'gcresnet33ts']
```
========Before pruning========
...
(norm_pre): Identity()
(head): NormMlpClassifierHead(
(global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Identity())
(norm): LayerNorm2d((3072,), eps=1e-05, elementwise_affine=True)
(flatten): Flatten(start_dim=1, end_dim=-1)
(pre_logits): Identity()
(drop): Dropout(p=0.0, inplace=False)
(fc): Linear(in_features=3072, out_features=1000, bias=True)
)
)
Unprunable: 175 models,
```
['bat_resnext26ts', 'caformer_b36', 'caformer_m36', 'caformer_s18', 'caformer_s36', 'cait_m36_384', 'cait_m48_448', 'cait_s24_224', 'cait_s24_384', 'cait_s36_384', 'cait_xs24_384', 'cait_xxs24_224', 'cait_xxs24_384', 'cait_xxs36_224', 'cait_xxs36_384', 'coat_lite_medium', 'coat_lite_medium_384', 'coat_lite_mini', 'coat_lite_small', 'coat_lite_tiny', 'coat_mini', 'coat_small', 'coat_tiny', 'coatnet_0_224', 'coatnet_0_rw_224', 'coatnet_1_224', 'coatnet_1_rw_224', 'coatnet_2_224', 'coatnet_2_rw_224', 'coatnet_3_224', 'coatnet_3_rw_224', 'coatnet_4_224', 'coatnet_5_224', 'coatnet_bn_0_rw_224', 'coatnet_nano_cc_224', 'coatnet_nano_rw_224', 'coatnet_pico_rw_224', 'coatnet_rmlp_0_rw_224', 'coatnet_rmlp_1_rw2_224', 'coatnet_rmlp_1_rw_224', 'coatnet_rmlp_2_rw_224', 'coatnet_rmlp_2_rw_384', 'coatnet_rmlp_3_rw_224', 'coatnet_rmlp_nano_rw_224', 'coatnext_nano_rw_224', 'convformer_b36', 'convformer_m36', 'convformer_s18', 'convformer_s36', 'convit_base', 'convit_small', 'convit_tiny', 'crossvit_9_240', 'crossvit_9_dagger_240', 'crossvit_15_240', 'crossvit_15_dagger_240', 'crossvit_15_dagger_408', 'crossvit_18_240', 'crossvit_18_dagger_240', 'crossvit_18_dagger_408', 'crossvit_base_240', 'crossvit_small_240', 'crossvit_tiny_240', 'cs3darknet_focus_l', 'cs3darknet_focus_m', 'cs3darknet_focus_s', 'cs3darknet_focus_x', 'cs3darknet_l', 'cs3darknet_m', 'cs3darknet_s', 'cs3darknet_x', 'cs3edgenet_x', 'cs3se_edgenet_x', 'cs3sedarknet_l', 'cs3sedarknet_x', 'cs3sedarknet_xdw', 'cspdarknet53', 'cspresnet50', 'cspresnet50d', 'cspresnet50w', 'cspresnext50', 'davit_base', 'davit_giant', 'davit_huge', 'davit_large', 'davit_small', 'davit_tiny', 'deit3_base_patch16_224', 'deit3_base_patch16_384', 'deit3_huge_patch14_224', 'deit3_large_patch16_224', 'deit3_large_patch16_384', 'deit3_medium_patch16_224', 'deit3_small_patch16_224', 'deit3_small_patch16_384', 'deit_base_distilled_patch16_224', 'deit_base_distilled_patch16_384', 'deit_base_patch16_224', 'deit_base_patch16_384', 
'deit_small_distilled_patch16_224', 'deit_small_patch16_224', 'deit_tiny_distilled_patch16_224', 'deit_tiny_patch16_224', 'densenetblur121d', 'dla60_res2net', 'dla60_res2next', 'dm_nfnet_f0', 'dm_nfnet_f1', 'dm_nfnet_f2', 'dm_nfnet_f3', 'dm_nfnet_f4', 'dm_nfnet_f5', 'dm_nfnet_f6', 'dpn48b', 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn107', 'dpn131', 'eca_halonext26ts', 'eca_nfnet_l0', 'eca_nfnet_l1', 'eca_nfnet_l2', 'eca_nfnet_l3', 'edgenext_base', 'edgenext_small', 'edgenext_small_rw', 'edgenext_x_small', 'edgenext_xx_small', 'efficientformer_l1', 'efficientformer_l3', 'efficientformer_l7', 'efficientformerv2_l', 'efficientformerv2_s0', 'efficientformerv2_s1', 'efficientformerv2_s2', 'efficientnet_b3_g8_gn', 'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e', 'ese_vovnet39b_evos', 'eva02_base_patch14_224', 'eva02_base_patch14_448', 'eva02_base_patch16_clip_224', 'eva02_enormous_patch14_clip_224', 'eva02_large_patch14_224', 'eva02_large_patch14_448', 'eva02_large_patch14_clip_224', 'eva02_large_patch14_clip_336', 'eva02_small_patch14_224', 'eva02_small_patch14_336', 'eva02_tiny_patch14_224', 'eva02_tiny_patch14_336', 'eva_giant_patch14_224', 'eva_giant_patch14_336', 'eva_giant_patch14_560', 'eva_giant_patch14_clip_224', 'eva_large_patch14_196', 'eva_large_patch14_336', 'flexivit_base', 'flexivit_large', 'flexivit_small', 'focalnet_base_lrf', 'focalnet_base_srf', 'focalnet_huge_fl3', 'focalnet_huge_fl4', 'focalnet_large_fl3', 'focalnet_large_fl4', 'focalnet_small_lrf', 'focalnet_small_srf', 'focalnet_tiny_lrf', 'focalnet_tiny_srf', 'focalnet_xlarge_fl3', 'focalnet_xlarge_fl4']
========After pruning========
...
(norm_pre): Identity()
(head): NormMlpClassifierHead(
(global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Identity())
(norm): LayerNorm2d((1536,), eps=1e-05, elementwise_affine=True)
(flatten): Flatten(start_dim=1, end_dim=-1)
(pre_logits): Identity()
(drop): Dropout(p=0.0, inplace=False)
(fc): Linear(in_features=1536, out_features=1000, bias=True)
)
)
MACs: 197.9920 G => 49.7716 G
Params: 846.4710 M => 213.2587 M
```
123 changes: 0 additions & 123 deletions examples/timm_models/timm_beit.py

This file was deleted.

64 changes: 0 additions & 64 deletions examples/timm_models/timm_pruning.py

This file was deleted.

Loading

0 comments on commit 474bfa6

Please sign in to comment.