Merge pull request #315 from VainF/v1.3
V1.3
VainF authored Dec 15, 2023
2 parents 1813e7a + b10e766 commit dedd368
Showing 6 changed files with 21 additions and 21 deletions.
21 changes: 12 additions & 9 deletions examples/torchvision_models/torchvision_pruning.py
@@ -222,8 +222,11 @@ def my_prune(model, example_inputs, output_transform, model_name):

     for g in pruner.step(interactive=True):
         g.prune()
+
     # or
     # pruner.step()
 
     if isinstance(pruner, (tp.pruner.BNScalePruner, tp.pruner.GroupNormPruner, tp.pruner.GrowingRegPruner)):
+        pruner.update_regularizor() # if the model has been pruned, we need to update the regularizor
         pruner.regularize(model)
+
     if isinstance(
@@ -304,14 +307,14 @@ def my_prune(model, example_inputs, output_transform, model_name):
     else:
         output_transform = None
 
-    try:
-        my_prune(
-            model, example_inputs=example_inputs, output_transform=output_transform, model_name=model_name
-        )
-        successful.append(model_name)
-    except Exception as e:
-        print(e)
-        unsuccessful.append(model_name)
+    #try:
+    my_prune(
+        model, example_inputs=example_inputs, output_transform=output_transform, model_name=model_name
+    )
+    successful.append(model_name)
+    #except Exception as e:
+    #    print(e)
+    #    unsuccessful.append(model_name)
 print("Successful Pruning: %d Models\n"%(len(successful)), successful)
 print("")
 print("Unsuccessful Pruning: %d Models\n"%(len(unsuccessful)), unsuccessful)
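To see the new call order end to end, here is a minimal sketch of the updated workflow. It is illustrative only: the model, the BNScaleImportance criterion, and the pruning_ratio argument follow the repository's README conventions rather than this commit, and the loss is a placeholder.

    import torch
    import torch_pruning as tp
    from torchvision.models import resnet18

    model = resnet18()
    example_inputs = torch.randn(1, 3, 224, 224)

    pruner = tp.pruner.BNScalePruner(
        model,
        example_inputs,
        importance=tp.importance.BNScaleImportance(),
        pruning_ratio=0.5,
    )

    # prune interactively, as in the example above
    for g in pruner.step(interactive=True):
        g.prune()

    # the graph changed, so the regularizer's cached groups must be refreshed
    pruner.update_regularizor()

    # one sparse-training step on the pruned model
    out = model(example_inputs)
    loss = out.sum()              # placeholder loss, for illustration only
    loss.backward()
    pruner.regularize(model)      # adds the sparsity term to the parameter gradients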
1 change: 1 addition & 0 deletions tests/test_regularization.py
@@ -44,6 +44,7 @@ def test_pruner():
             grad_dict[p] = p.grad.clone()
         else:
             grad_dict[p] = None
+    pruner.update_regularizor()
     pruner.regularize(model)
     for name, p in model.named_parameters():
         if p.grad is not None and grad_dict[p] is not None:
6 changes: 2 additions & 4 deletions torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py
@@ -85,11 +85,9 @@ def __init__(
         if self.group_lasso:
             self._l2_imp = MagnitudeImportance(p=2, group_reduction='mean', normalizer=None, target_types=[nn.modules.batchnorm._BatchNorm])
 
-    def step(self, interactive=False):
-        yield from super(BNScalePruner, self).step(interactive=interactive)
-        # Update the group list after pruning
+    def update_regularizor(self):
         self._groups = list(self.DG.get_all_groups(root_module_types=self.root_module_types, ignored_layers=self.ignored_layers))
 
     def regularize(self, model, reg=None, bias=False):
         if reg is None:
             reg = self.reg # use the default reg
6 changes: 2 additions & 4 deletions torch_pruning/pruner/algorithms/group_norm_pruner.py
@@ -82,11 +82,9 @@ def __init__(
         self._groups = list(self.DG.get_all_groups(root_module_types=self.root_module_types, ignored_layers=self.ignored_layers))
         self.cnt = 0
 
-    def step(self, interactive=False):
-        yield from super(GroupNormPruner, self).step(interactive=interactive)
-        # update the group list after pruning
+    def update_regularizor(self):
         self._groups = list(self.DG.get_all_groups(root_module_types=self.root_module_types, ignored_layers=self.ignored_layers))
 
     @torch.no_grad()
     def regularize(self, model, alpha=2**4, bias=False):
         for i, group in enumerate(self._groups):
5 changes: 1 addition & 4 deletions torch_pruning/pruner/algorithms/growing_reg_pruner.py
@@ -96,16 +96,13 @@ def update_reg(self):
             reg = reg + self.delta_reg * standarized_imp.to(reg.device)
             self.group_reg[group] = reg
 
-    def step(self, interactive=False):
-        yield from super(GrowingRegPruner, self).step(interactive=interactive)
-
+    def update_regularizor(self):
         # Update the group list after pruning
         self._groups = list(self.DG.get_all_groups(root_module_types=self.root_module_types, ignored_layers=self.ignored_layers))
         group_reg = {}
         for group in self._groups:
             group_reg[group] = torch.ones(len(group[0].idxs)) * self.base_reg
         self.group_reg = group_reg
-
 
     def regularize(self, model, bias=False):
         for i, group in enumerate(self._groups):
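All three regularization-based pruners now rely on this explicit hook instead of refreshing their group lists inside an overridden step(). In an iterative schedule, the hook is called once per pruning round; the sketch below assumes a v1.3-style constructor (iterative_steps, pruning_ratio) and uses placeholder training code.

    import torch
    import torch_pruning as tp
    from torchvision.models import resnet18

    model = resnet18()
    example_inputs = torch.randn(1, 3, 224, 224)

    pruner = tp.pruner.GroupNormPruner(
        model,
        example_inputs,
        importance=tp.importance.MagnitudeImportance(p=2),
        iterative_steps=5,
        pruning_ratio=0.5,
    )

    for pruning_round in range(5):
        # ... sparse-training epochs that call pruner.regularize(model) ...
        pruner.step()                # prune one iterative step
        pruner.update_regularizor()  # re-collect groups from the pruned graph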
3 changes: 3 additions & 0 deletions torch_pruning/pruner/algorithms/metapruner.py
@@ -247,6 +247,9 @@ def get_target_head_pruning_ratio(self, module) -> float:
     def reset(self) -> None:
         self.current_step = 0
 
+    def update_regularizer(self) -> None:
+        pass
+
     def regularize(self, model, loss) -> typing.Any:
         """ Model regularizor for sparse training
         """
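Because the base class ships the hook as a no-op, custom pruners can opt in by overriding it. The subclass below is a hypothetical sketch, not part of the library: MyL1Pruner and its toy BatchNorm L1 regularizer only illustrate how _groups can be kept in sync with the pruned graph (only self.DG.get_all_groups, root_module_types, and ignored_layers come from this commit).

    import torch
    import torch.nn as nn
    import torch_pruning as tp

    class MyL1Pruner(tp.pruner.MetaPruner):
        def __init__(self, *args, reg=1e-4, **kwargs):
            super().__init__(*args, **kwargs)
            self.reg = reg
            self.update_regularizer()  # build the initial group list

        def update_regularizer(self) -> None:
            # re-collect prunable groups from the (possibly pruned) graph
            self._groups = list(self.DG.get_all_groups(
                root_module_types=self.root_module_types,
                ignored_layers=self.ignored_layers))

        def regularize(self, model, bias=False):
            # toy L1 sparsity gradient on BatchNorm scales within each group
            for group in self._groups:
                for dep, idxs in group:
                    layer = dep.target.module
                    if isinstance(layer, nn.modules.batchnorm._BatchNorm) \
                            and layer.weight is not None and layer.weight.grad is not None:
                        layer.weight.grad.add_(self.reg * torch.sign(layer.weight.data))

Construction mirrors any MetaPruner (model, example_inputs, importance=...); after each pruning step, calling the hook keeps _groups consistent with the new graph, exactly as the built-in pruners above now do.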
