quantize.py
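"""Post-training quantization and binarization evaluation for DNSMOS-style speech-quality models.

For every trained checkpoint, this script evaluates the full-precision model,
an int8 post-training-quantized version (calibrated on a held-out slice of the
training set), and a binary-activation version, then pickles the metrics and
the raw predictions/targets under ./binarize_results/<evaluation_set>/.
"""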
import os
import sys
from glob import glob

import dotenv
import hydra
import pandas as pd
import torch
import torch.nn as nn
from hydra import compose, initialize
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from sqp_experiments.utils import register_resolvers, train_valid_split, compute_metrics, seed_everything

def get_hydra_config(overrides=None):
    """Compose the training config via the Hydra Compose API."""
    if overrides is None:
        overrides = []
    # Clear any previous global Hydra instance so initialize() can be called repeatedly
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    initialize(version_base=None, config_path="config")
    config = compose(config_name="train.yaml", overrides=["workers=0", *overrides])
    register_resolvers()
    return config

def load_dataloaders(calibr_percentage=0.2):
    """Build calibration, validation, and test dataloaders from the Hydra config."""
    config = get_hydra_config()
    # Instantiate datasets; the calibration subset is drawn from the training
    # split, so quantization never sees validation or test data
    dataset = hydra.utils.instantiate(config.dataset)
    dataset_train, dataset_valid = train_valid_split(dataset=dataset, valid_split=config.train_valid_split)
    calibr_indices = torch.randperm(len(dataset_train))[:int(calibr_percentage * len(dataset_train))].numpy()
    dataset_calibr = Subset(dataset_train, calibr_indices)
    dataset_test = hydra.utils.instantiate(config.dataset_test)
    # Instantiate dataloaders
    dataloader_valid = DataLoader(dataset=dataset_valid, batch_size=config.valid_batch_size, num_workers=config.workers, shuffle=False)
    dataloader_calibr = DataLoader(dataset=dataset_calibr, batch_size=config.batch_size, num_workers=config.workers, shuffle=True)
    dataloader_test = DataLoader(dataset=dataset_test, batch_size=config.valid_batch_size, num_workers=config.workers, shuffle=False)
    # Show dataset info
    print(f"Train dataset: {len(dataset_train)} samples")
    print(f"Valid dataset: {len(dataset_valid)} samples")
    print(f"Test dataset: {len(dataset_test)} samples")
    print(f"Calibr dataset: {len(dataset_calibr)} samples ({calibr_percentage * 100}% of train dataset)")
    return dataloader_calibr, dataloader_valid, dataloader_test

def load_model(model_weights, model_name):
    """Instantiate the model named in the Hydra config and load its checkpoint."""
    config = get_hydra_config([f"model={model_name}"])
    # Instantiate model
    model_fp = hydra.utils.instantiate(config.model)
    # Load weights, reporting any mismatch between checkpoint and architecture
    state_dict = torch.load(model_weights, map_location='cpu')
    missing_keys, unexpected_keys = model_fp.load_state_dict(state_dict, strict=False)
    if len(missing_keys) > 0:
        print(f"Missing keys: {missing_keys}")
    if len(unexpected_keys) > 0:
        print(f"Unexpected keys: {unexpected_keys}")
    return model_fp

def quantize_model(model_fp, dataloader_calibr, backend="x86", _replace_heaviside=False):
    """Post-training static quantization: observe activations on the calibration set, then convert to int8."""
    model_fp.eval()
    # Manual same-padding
    model_fp.conv1[0].padding = (1, 1)
    model_fp.conv2[0].padding = (1, 1)
    model_fp.conv3[0].padding = (1, 1)
    model_fp.conv4[0].padding = (1, 1)
    # Optionally swap the Heaviside activations for a quantization-friendly surrogate
    if _replace_heaviside:
        model_fp.conv1 = replace_heaviside(model_fp.conv1)
        model_fp.conv2 = replace_heaviside(model_fp.conv2)
        model_fp.conv3 = replace_heaviside(model_fp.conv3)
        model_fp.conv4 = replace_heaviside(model_fp.conv4)
    # Attach observers; QuantWrapper adds quant/dequant stubs around the model
    model_fp.qconfig = torch.ao.quantization.get_default_qconfig(backend)
    model_fp_wrapped = torch.ao.quantization.QuantWrapper(model_fp)
    model_fp_prepared = torch.ao.quantization.prepare(model_fp_wrapped, inplace=False)
    # Calibrate: run the calibration set through the prepared model so the
    # observers record activation ranges
    with torch.no_grad():
        for data in tqdm(dataloader_calibr):
            input_data, _ = data
            # Preprocess data (log-mel spectrogram)
            melspec_input = model_fp.preproc(input_data)
            log_melspec_input = torch.log(melspec_input + model_fp.eps)
            # Process minibatch
            _ = model_fp_prepared(log_melspec_input)
    model_int = torch.ao.quantization.convert(model_fp_prepared, inplace=False)
    return model_int
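
# Note (assumption about the deployment target): get_default_qconfig's backend
# ("x86" here) should match the quantized engine active at inference time
# (torch.backends.quantized.engine); "qnnpack" would be the usual choice on ARM.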

def test(model, dataloader):
    """Run inference over a dataloader; return averaged metrics plus raw predictions and targets."""
    model.eval()
    pred_list = []
    true_list = []
    # QuantWrapper hides the original model behind .module
    is_wrapped = isinstance(model, torch.ao.quantization.QuantWrapper)
    eps = model.module.eps if is_wrapped else model.eps
    # Loop over data (catch Ctrl+C so an interrupted run still returns cleanly)
    try:
        for data in tqdm(dataloader, colour="green"):
            input_data, target_data = data
            with torch.no_grad():
                # Preprocess data
                if is_wrapped:
                    melspec_input = model.module.preproc(input_data)
                else:
                    melspec_input = model.preproc(input_data)
                log_melspec_input = torch.log(melspec_input + eps)
                # Run inference
                mos_pred = model(log_melspec_input)
                mos_true = target_data
                # Store predictions and targets
                pred_list.append(mos_pred.cpu())
                true_list.append(mos_true.cpu())
    except KeyboardInterrupt:
        print("User interrupted testing, returning")
        return {}, None, None
    # Reshape predictions and targets (predictions have shape (batch, 1); keep the first column)
    pred_tensor = torch.cat(pred_list)[:, 0]
    true_tensor = torch.cat(true_list)
    # Compute metrics
    avg_metrics = compute_metrics(pred_tensor, true_tensor, colour="green")
    return avg_metrics, pred_tensor, true_tensor

def get_weights_paths(model_type):
    """Collect checkpoint paths for all runs of the given model type."""
    root_dir = 'trained_models'
    models_root = os.path.join(root_dir, model_type)
    # One checkpoint per run directory (recursive=True is unnecessary without a ** pattern)
    weights_list = glob(os.path.join(models_root, "*", "pytorch_model.bin"))
    weights_list = sorted(weights_list)
    print(f"Found {len(weights_list)} weights: {weights_list}")
    return weights_list

def replace_heaviside(seq_block, k=1000.0):
    """Replace the Heaviside activation in a conv block with a steep hardsigmoid approximation."""
    seq_block_layers = list(seq_block.children())
    seq_block_layers[1] = nn.Hardsigmoid()
    # Use a frozen batchnorm as a plain scalar multiplier: in eval mode with
    # identity running stats, it computes y = k * x
    batch_norm = nn.BatchNorm2d(seq_block_layers[0].out_channels)
    batch_norm.weight.data.fill_(k)   # multiplier
    batch_norm.bias.data.fill_(0.0)   # bias
    batch_norm.eval()  # disable running-statistics updates
    for param in batch_norm.parameters():
        param.requires_grad = False
    # Insert the batchnorm before the hardsigmoid
    seq_block_layers.insert(1, batch_norm)
    # Re-create the sequential block
    seq_block = nn.Sequential(*seq_block_layers)
    return seq_block
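
# Why this approximates Heaviside: nn.Hardsigmoid computes clamp(x / 6 + 0.5, 0, 1),
# so after the k = 1000 multiplier the activation saturates to 0 for x < -3/k and
# to 1 for x > 3/k, matching a step function outside a narrow (±0.003) band around
# zero, while (presumably the reason for the swap) hardsigmoid, unlike a hard step,
# is supported by PyTorch's quantization tooling.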

def main(evaluation_set):
    if evaluation_set not in ["valid", "test"]:
        raise ValueError(f"evaluation_set must be either 'valid' or 'test', got '{evaluation_set}'")
    seed_everything(42)
    dotenv.load_dotenv(override=True)
    # Load dataloaders
    dataloader_calibr, dataloader_valid, dataloader_test = load_dataloaders()
    print(f"Using {evaluation_set} dataloader for evaluation")
    # Pick which evaluation set to use
    dataloader_current = {
        "valid": dataloader_valid,
        "test": dataloader_test,
    }[evaluation_set]
    # For each baseline model: load weights, binarize/quantize, and test
    weights_list = get_weights_paths("baseline")
    bl_metrics = {}
    ptb_metrics = {}
    ptq_metrics = {}
    for i, weights_path in enumerate(weights_list):
        print(f"Processing model {i}...")
        # Full-precision baseline
        model_fp = load_model(weights_path, "dnsmos_baseline_nolog")
        curr_bl_metrics, bl_pred_tensor, bl_true_tensor = test(model_fp, dataloader_current)
        bl_metrics[i] = curr_bl_metrics
        print(f"Baseline model {i} metrics: {curr_bl_metrics}")
        # Post-training quantization
        model_ptq = quantize_model(model_fp, dataloader_calibr)
        curr_ptq_metrics, ptq_pred_tensor, ptq_true_tensor = test(model_ptq, dataloader_current)
        ptq_metrics[i] = curr_ptq_metrics
        print(f"Post-training quantized model {i} metrics: {curr_ptq_metrics}")
        # Post-training binarization: reload the same weights into the binary-activation architecture
        model_ptb = load_model(weights_path, "dnsmos_binary_nolog")
        curr_ptb_metrics, ptb_pred_tensor, ptb_true_tensor = test(model_ptb, dataloader_current)
        ptb_metrics[i] = curr_ptb_metrics
        print(f"Post-training binarized model {i} metrics: {curr_ptb_metrics}")
    # Save metrics
    bl_metrics_df = pd.DataFrame.from_dict(bl_metrics, orient='index')
    ptb_metrics_df = pd.DataFrame.from_dict(ptb_metrics, orient='index')
    ptq_metrics_df = pd.DataFrame.from_dict(ptq_metrics, orient='index')
    if not os.path.exists(f"./binarize_results/{evaluation_set}"):
        print(f"Creating directory ./binarize_results/{evaluation_set}")
        os.makedirs(f"./binarize_results/{evaluation_set}", exist_ok=True)
    bl_metrics_df.to_pickle(f"./binarize_results/{evaluation_set}/bl_metrics.pkl")
    ptb_metrics_df.to_pickle(f"./binarize_results/{evaluation_set}/ptb_metrics.pkl")
    ptq_metrics_df.to_pickle(f"./binarize_results/{evaluation_set}/ptq_metrics.pkl")
    # For each binary-activation model: load weights, quantize, and test
    weights_list = get_weights_paths("binact")
    bin_metrics = {}
    qbin_metrics = {}
    for i, weights_path in enumerate(weights_list):
        print(f"Processing model {i}...")
        model_bin = load_model(weights_path, "dnsmos_binary_nolog")
        curr_bin_metrics, bin_pred_tensor, bin_true_tensor = test(model_bin, dataloader_current)
        bin_metrics[i] = curr_bin_metrics
        print(f"Binary-activation model {i} metrics: {curr_bin_metrics}")
        # Quantize (the Heaviside activations must first be replaced with a quantizable surrogate)
        model_qbin = quantize_model(model_bin, dataloader_calibr, _replace_heaviside=True)
        curr_qbin_metrics, qbin_pred_tensor, qbin_true_tensor = test(model_qbin, dataloader_current)
        qbin_metrics[i] = curr_qbin_metrics
        print(f"Post-training quantized binary-activation model {i} metrics: {curr_qbin_metrics}")
    # Save metrics
    bin_metrics_df = pd.DataFrame.from_dict(bin_metrics, orient='index')
    qbin_metrics_df = pd.DataFrame.from_dict(qbin_metrics, orient='index')
    bin_metrics_df.to_pickle(f"./binarize_results/{evaluation_set}/bin_metrics.pkl")
    qbin_metrics_df.to_pickle(f"./binarize_results/{evaluation_set}/qbin_metrics.pkl")
    # Combine metrics into a single long-format dataframe
    combined_metrics = {
        "baseline": bl_metrics_df,
        "baseline_bin": ptb_metrics_df,
        "baseline_ptq": ptq_metrics_df,
        "binact": bin_metrics_df,
        "binact_ptq": qbin_metrics_df,
    }
    for model_name, metrics_df in combined_metrics.items():
        metrics_df["model"] = model_name
    combined_metrics_df = pd.concat(combined_metrics.values(), ignore_index=True)
    combined_metrics_df.to_pickle(f"./binarize_results/{evaluation_set}/combined_metrics.pkl")
    # Save predictions and targets (note: only the tensors from the last model of each loop are kept)
    torch.save({
        "baseline": (bl_pred_tensor, bl_true_tensor),
        "baseline_bin": (ptb_pred_tensor, ptb_true_tensor),
        "baseline_ptq": (ptq_pred_tensor, ptq_true_tensor),
        "binact": (bin_pred_tensor, bin_true_tensor),
        "binact_ptq": (qbin_pred_tensor, qbin_true_tensor),
    }, f"./binarize_results/{evaluation_set}/preds_targets.pt")
    print("Done!")


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python quantize.py <valid|test>")
        sys.exit(1)
    evaluation_set = sys.argv[1]
    main(evaluation_set)
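
# Example invocations; checkpoints are expected under
# ./trained_models/{baseline,binact}/*/pytorch_model.bin:
#   python quantize.py valid
#   python quantize.py test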