Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/VainF/Torch-Pruning
Browse files Browse the repository at this point in the history
  • Loading branch information
VainF committed Mar 29, 2023
2 parents 57a7fcd + 3d031ca commit cc3d69d
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,13 @@ This example demonstrates the fundamental pruning pipeline using DepGraph. Note
For more details about grouping, please refer to [tutorials/2 - Exploring Dependency Groups](https://github.com/VainF/Torch-Pruning/blob/master/tutorials/2%20-%20Exploring%20Dependency%20Groups.ipynb)

#### How to scan all groups:
Just like what we do in the [MetaPruner](https://github.com/VainF/Torch-Pruning/blob/b607ae3aa61b9dafe19d2c2364f7e4984983afbf/torch_pruning/pruner/algorithms/metapruner.py#L197), one can use ``DG.get_all_groups(ignored_layers, root_module_types)`` to iterate all groups. Specifically, all groups will all begin with a layer that matches the type specified by the "root_module_types" parameter. These groups contain full ``idxs`` that covers all prunable parameters. If you intend to prune particular channels/dimensions, you can create a new idxs list and re-generate the group. In future versions, we will introduce a more user-friendly interface like ``group.prune(idxs=[2,4,6])``.
Just like what we do in the [MetaPruner](https://github.com/VainF/Torch-Pruning/blob/b607ae3aa61b9dafe19d2c2364f7e4984983afbf/torch_pruning/pruner/algorithms/metapruner.py#L197), one can use ``DG.get_all_groups(ignored_layers, root_module_types)`` to iterate all groups. Specifically, all groups will begin with a layer that matches the type specified by the "root_module_types" parameter. These groups contain a full index list ``idxs=[0,1,2,3,...,K]`` that covers all prunable parameters. If you are intended to prune partial channels/dimensions, you can use ``group.prune(idxs=idxs)``.

```python
for group in DG.get_all_groups():
for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):
# handle groups in sequential order
idxs = [2,4,6] # my pruning indices
root_module = group[0].dep.target.module # the root module of this group
pruning_fn = group[0].dep.handler # the pruning function
group = DG.get_pruning_group(root_module, pruning_fn, idxs) # get a group with desired pruning idxs
idxs = [2,4,6] # your pruning indices
group.prune(idxs=idxs)
print(group)
```

Expand Down

0 comments on commit cc3d69d

Please sign in to comment.