[NNCF] semantic seg pruning #90

Open · wants to merge 7 commits into base: ote

10 changes: 6 additions & 4 deletions mmseg/apis/export.py

```diff
@@ -157,7 +157,7 @@ def _get_mo_cmd():
 
 
 def export_to_openvino(cfg, onnx_model_path, output_dir_path, input_shape=None,
-                       input_format='rgb', precision='FP32'):
+                       input_format='rgb', precision='FP32', pruning_transformation=False):
     cfg.model.pretrained = None
     cfg.data.test.test_mode = True
 
@@ -192,7 +192,8 @@ def export_to_openvino(cfg, onnx_model_path, output_dir_path, input_shape=None,
     if normalize['to_rgb'] and input_format.lower() == 'bgr' or \
             not normalize['to_rgb'] and input_format.lower() == 'rgb':
         command_line.append('--reverse_input_channels')
-
+    if pruning_transformation:
+        command_line.extend(['--transform', 'Pruning'])
     run(command_line, shell=False, check=True)
 
 
@@ -239,7 +240,7 @@ def _get_fake_inputs(input_shape, num_classes):
 
 
 def export_model(model, config, output_dir, target='openvino', onnx_opset=11,
-                 input_format='rgb', precision='FP32', output_logits=False):
+                 input_format='rgb', precision='FP32', output_logits=False, pruning_transformation=False):
     assert onnx_opset in available_opsets
 
     if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
@@ -277,6 +278,7 @@ def export_model(model, config, output_dir, target='openvino', onnx_opset=11,
                            output_dir,
                            input_shape,
                            input_format,
-                           precision)
+                           precision,
+                           pruning_transformation=pruning_transformation)
     else:
         check_onnx_model(onnx_model_path)
```
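
The `pruning_transformation` flag is threaded from `export_model` down to `export_to_openvino`, where it appends `--transform Pruning` to the Model Optimizer command line so that channels zeroed out by NNCF filter pruning are physically removed from the resulting IR. A minimal usage sketch, assuming `model` and `cfg` come from a finished NNCF pruning run (the output directory is illustrative):

```python
# Minimal sketch: exporting an NNCF-pruned segmentor to OpenVINO IR.
# `model` and `cfg` are assumed to come from a finished pruning run.
from mmseg.apis.export import export_model

export_model(
    model,                        # NNCF-wrapped segmentor; (Distributed)DataParallel is unwrapped inside
    cfg,                          # the mmseg config used for training
    'export/pruned',              # output_dir (illustrative)
    target='openvino',
    precision='FP32',
    pruning_transformation=True,  # adds `--transform Pruning` to the MO command line
)
```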

16 changes: 13 additions & 3 deletions mmseg/apis/train.py

```diff
@@ -36,6 +36,7 @@
 from mmseg.models import build_params_manager
 from mmseg.models.losses import MarginCalibrationLoss
 from mmseg.integration.nncf import wrap_nncf_model
+from mmseg.integration.nncf import AccuracyAwareLrUpdater
 from mmseg.integration.nncf import is_accuracy_aware_training_set
 from mmseg.apis.fake_input import get_fake_input
 from mmseg.integration.nncf import CompressionHook
@@ -243,12 +244,19 @@ def train_segmentor(model,
 
     # register training hooks
     runner.register_training_hooks(
-        cfg.lr_config,
+        None,
         optimizer_config,
         cfg.checkpoint_config,
         cfg.log_config,
         cfg.get('momentum_config', None)
     )
+    # register lr updater hook
+    policy_type = cfg.lr_config.pop('policy')
+    if policy_type == policy_type.lower():
+        policy_type = policy_type.title()
+    cfg.lr_config['type'] = policy_type + 'LrUpdaterHook'
+    lr_updater_hook = build_from_cfg(cfg.lr_config, HOOKS)
+    runner.register_lr_hook(lr_updater_hook)
 
     # register parameters manager hook
     params_manager_cfg = cfg.get('params_config', None)
@@ -292,7 +300,8 @@ def train_segmentor(model,
         for hook_cfg in cfg.custom_hooks:
             assert isinstance(hook_cfg, dict), f'Each item in custom_hooks expects dict type, but got ' \
                                                f'{type(hook_cfg)}'
-
+            if nncf_is_acc_aware_training_set and hook_cfg.get('type') == 'EarlyStoppingHook':
+                continue
             hook_cfg = hook_cfg.copy()
             priority = hook_cfg.pop('priority', 'NORMAL')
             hook = build_from_cfg(hook_cfg, HOOKS)
@@ -321,7 +330,8 @@ def train_segmentor(model,
     if nncf_is_acc_aware_training_set:
         def configure_optimizers_fn():
             optimizer = build_optimizer(runner.model, cfg.optimizer)
-            return optimizer, None
+            lr_scheduler = AccuracyAwareLrUpdater(lr_updater_hook, runner, optimizer)
+            return optimizer, lr_scheduler
 
         runner.run(
             data_loaders,
```
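
`register_training_hooks` now receives `None` in place of `cfg.lr_config` because the LR updater hook is built explicitly: holding a reference to the hook object lets `configure_optimizers_fn` hand it to `AccuracyAwareLrUpdater` below. The policy-name normalization replicates what mmcv's `register_lr_hook` does internally, so configs can keep lowercase policy names. A self-contained restatement of that normalization:

```python
# Sketch of the policy-name normalization above: an all-lowercase policy
# such as 'poly' is title-cased before the hook suffix is appended, while
# a mixed-case name like 'CosineAnnealing' is kept as written.
def lr_hook_type(policy: str) -> str:
    if policy == policy.lower():
        policy = policy.title()
    return policy + 'LrUpdaterHook'

assert lr_hook_type('poly') == 'PolyLrUpdaterHook'
assert lr_hook_type('CosineAnnealing') == 'CosineAnnealingLrUpdaterHook'
```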

2 changes: 2 additions & 0 deletions mmseg/integration/nncf/__init__.py

```diff
@@ -10,6 +10,7 @@
     is_checkpoint_nncf,
     wrap_nncf_model,
     is_state_nncf,
+    AccuracyAwareLrUpdater,
 )
 
 from .compression_hooks import CompressionHook, CheckpointHookBeforeTraining
@@ -23,6 +24,7 @@
 
 __all__ = [
     'AccuracyAwareRunner',
+    'AccuracyAwareLrUpdater',
     'CheckpointHookBeforeTraining',
     'CompressionHook',
     'check_nncf_is_enabled',
```

21 changes: 21 additions & 0 deletions mmseg/integration/nncf/compression.py

```diff
@@ -280,3 +280,24 @@ def get_uncompressed_model(module):
     if isinstance(module, NNCFNetwork):
         return module.get_nncf_wrapped_model()
     return module
+
+
+class AccuracyAwareLrUpdater:
+    def __init__(self, lr_hook, runner, optimizer=None):
+        self._lr_hook = lr_hook
+        self._runner = runner
+        if optimizer:
+            runner.optimizer = optimizer
+        self._lr_hook.before_run(runner)
+        self._lr_hook.warmup_iters = 0
+
+    def step(self, *args, **kwargs):
+        pass
+
+    @property
+    def base_lrs(self):
+        return self._lr_hook.base_lr
+
+    @base_lrs.setter
+    def base_lrs(self, value):
+        self._lr_hook.base_lr = value
```
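
`AccuracyAwareLrUpdater` adapts an mmcv LR updater hook to the scheduler-like interface NNCF's accuracy-aware loop expects: the constructor re-attaches the (possibly fresh) optimizer to the runner and re-runs the hook's `before_run` so the base LR is recomputed, warmup is disabled, `step()` is a no-op because the hook keeps stepping through the runner's own hook calls, and `base_lrs` gives the loop read/write access to the base learning rate. A sketch of the wiring, mirroring the `configure_optimizers_fn` change in `train.py` above (`runner`, `cfg`, and `lr_updater_hook` are the objects built earlier in `train_segmentor`):

```python
from mmcv.runner import build_optimizer

from mmseg.integration.nncf import AccuracyAwareLrUpdater

def configure_optimizers_fn():
    # A fresh optimizer is built for each accuracy-aware stage; the adapter
    # attaches it to the runner and re-initializes the LR hook against it.
    optimizer = build_optimizer(runner.model, cfg.optimizer)
    lr_scheduler = AccuracyAwareLrUpdater(lr_updater_hook, runner, optimizer)
    return optimizer, lr_scheduler
```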

2 changes: 2 additions & 0 deletions mmseg/integration/nncf/compression_hooks.py

```diff
@@ -16,6 +16,8 @@ def after_train_iter(self, runner):
 
     def after_train_epoch(self, runner):
         self.compression_ctrl.scheduler.epoch_step()
+        if runner.rank == 0:
+            runner.logger.info(self.compression_ctrl.statistics().to_str())
 
     def before_run(self, runner):
         if runner.rank == 0:
```
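
With this change the NNCF statistics report (compression rate, per-layer pruning levels, and so on) goes to the training log once per epoch, not only before training starts. The `runner.rank == 0` guard is the usual mmcv pattern for keeping distributed runs from printing the report once per process. A minimal sketch of the pattern in isolation (the hook name and constructor are hypothetical, added only to make the snippet self-contained):

```python
from mmcv.runner import HOOKS, Hook

@HOOKS.register_module()
class StatsLoggingHook(Hook):  # hypothetical hook, for illustration only
    def __init__(self, compression_ctrl):
        self.compression_ctrl = compression_ctrl

    def after_train_epoch(self, runner):
        # Log only on rank 0 so DDP runs print the NNCF statistics report
        # once per epoch instead of once per GPU.
        if runner.rank == 0:
            runner.logger.info(self.compression_ctrl.statistics().to_str())
```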

6 changes: 3 additions & 3 deletions mmseg/integration/nncf/runners.py

```diff
@@ -38,16 +38,16 @@ def run(self, data_loaders, *args, compression_ctrl=None,
                          get_host_info(), work_dir)
         self.logger.warning('Note that the workflow and max_epochs parameters '
                             'are not used in NNCF-based accuracy-aware training')
+        acc_aware_training_loop = create_accuracy_aware_training_loop(nncf_config,
+                                                                      compression_ctrl, verbose=False)
         # taking only the first data loader for NNCF training
         self.train_data_loader = data_loaders[0]
         # Maximum possible number of iterations, needs for progress tracking
-        self._max_epochs = nncf_config["accuracy_aware_training"]["params"]["maximal_total_epochs"]
+        self._max_epochs = acc_aware_training_loop.runner.maximal_total_epochs
         self._max_iters = self._max_epochs * len(self.train_data_loader)
 
         self.call_hook('before_run')
 
-        acc_aware_training_loop = create_accuracy_aware_training_loop(nncf_config,
-                                                                      compression_ctrl)
         model = acc_aware_training_loop.run(self.model,
                                             train_epoch_fn=self.train_fn,
                                             validate_fn=self.validation_fn,
```
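
Creating the accuracy-aware loop before `_max_epochs` is computed lets the runner take its epoch budget from the loop itself (`acc_aware_training_loop.runner.maximal_total_epochs`) instead of indexing into the raw `nncf_config` dict, so defaults that NNCF fills in for omitted keys are picked up as well; `verbose=False` presumably quiets the loop's own console output. For reference, a sketch of the config fragment the loop consumes (only the `accuracy_aware_training` -> `params` -> `maximal_total_epochs` path is confirmed by the removed line above; the value is a placeholder):

```python
# Illustrative NNCF config fragment for accuracy-aware training.
# Only the key path shown in the removed line is confirmed by this diff;
# the epoch count is an example placeholder.
nncf_config = {
    "accuracy_aware_training": {
        "params": {
            "maximal_total_epochs": 100,  # upper bound used for progress tracking
        },
    },
}
```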