-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathcolour_mnist.py
190 lines (147 loc) · 7.69 KB
/
colour_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""ReBias
Copyright (c) 2020-present NAVER Corp.
MIT license
Python implementation of Biased-MNIST.
"""
import os
import numpy as np
from PIL import Image
import torch
from torch.utils import data
from torchvision import transforms
from torchvision.datasets import MNIST
class BiasedMNIST(MNIST):
    """A base class for Biased-MNIST.

    Ten colours are manually selected to synthesise a colour bias (see
    `COLOUR_MAP` for the colour configuration). Usage is the same as the
    torchvision MNIST dataset class, except that `__getitem__` additionally
    returns the bias label of the sample.

    Two parameters control the level of bias.

    Parameters
    ----------
    root : str
        Path to the MNIST dataset.
    train, transform, target_transform, download
        Forwarded unchanged to the parent ``MNIST`` constructor.
    data_label_correlation : float, default=1.0
        Each class has a pre-defined colour (bias).
        `data_label_correlation`, or `rho`, controls the level of the dataset
        bias. A sample is coloured with

        - the pre-defined colour with probability `rho`,
        - one of the other colours with probability `1 - rho`.

        The number of "other colours" is controlled by `n_confusing_labels`
        (default: 9). Note that the colour is injected into the background of
        the image (see `_binary_to_colour`). Hence, we have

        - a perfectly biased dataset with rho=1.0,
        - a perfectly unbiased dataset with rho=0.1 (1/10), which is the
          "unbiased" setting used at test time.

        In the paper, high correlations with small hints are explored,
        e.g., rho=0.999. Must lie in [0, 1].
    n_confusing_labels : int, default=9
        In real-world cases, biases are not equally distributed but highly
        unbalanced. The unbalanced biases are mimicked by changing the number
        of confusing colours for each class. In the paper,
        n_confusing_labels=9 is used, i.e., during training the model can
        observe all colours for each class. The problem can be made harder by
        setting a smaller n_confusing_labels, e.g., 2. Must lie in [1, 9].

    Raises
    ------
    ValueError
        If `data_label_correlation` or `n_confusing_labels` is out of range.
    """

    # One RGB colour per bias label (index == bias label).
    # NOTE(review): two entries use 225 rather than 255; left unchanged so the
    # generated colours stay the same -- confirm whether this is intended.
    COLOUR_MAP = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [225, 225, 0], [225, 0, 225],
                  [0, 255, 255], [255, 128, 0], [255, 0, 128], [128, 0, 255], [128, 128, 128]]

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, data_label_correlation=1.0, n_confusing_labels=9):
        # Fail fast on a bad configuration, before the parent constructor runs.
        self._check_bias_parameters(data_label_correlation, n_confusing_labels)

        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform,
                         download=download)
        # When False, `_shuffle` becomes a no-op and sample order is kept.
        self.random = True

        self.data_label_correlation = data_label_correlation
        self.n_confusing_labels = n_confusing_labels
        self.data, self.targets, self.biased_targets = self.build_biased_mnist()

        # `build_biased_mnist` returns the samples grouped by bias label;
        # apply one shared permutation so the three arrays stay aligned.
        indices = np.arange(len(self.data))
        self._shuffle(indices)

        # From here on the images are held as a numpy array.
        self.data = self.data[indices].numpy()
        self.targets = self.targets[indices]
        self.biased_targets = self.biased_targets[indices]

    @staticmethod
    def _check_bias_parameters(data_label_correlation, n_confusing_labels):
        """Raise ``ValueError`` if either bias parameter is out of range."""
        if n_confusing_labels > 9 or n_confusing_labels < 1:
            raise ValueError(
                'n_confusing_labels must be in [1, 9], got {!r}'.format(n_confusing_labels))
        # Written as a negated chained comparison so that NaN is rejected too.
        if not 0.0 <= data_label_correlation <= 1.0:
            raise ValueError(
                'data_label_correlation must be in [0, 1], got {!r}'.format(data_label_correlation))

    @property
    def raw_folder(self):
        """Return ``<root>/raw``.

        NOTE(review): presumably overrides the location used by the
        torchvision parent class -- confirm against torchvision.
        """
        return os.path.join(self.root, 'raw')

    @property
    def processed_folder(self):
        """Return ``<root>/processed``.

        NOTE(review): presumably overrides the location used by the
        torchvision parent class -- confirm against torchvision.
        """
        return os.path.join(self.root, 'processed')

    def _shuffle(self, iteratable):
        """Shuffle `iteratable` in place with ``np.random`` if `self.random` is set."""
        if self.random:
            np.random.shuffle(iteratable)

    def _make_biased_mnist(self, indices, label):
        """Return ``(data, targets)`` for `indices` with bias `label` applied.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def _update_bias_indices(self, bias_indices, label):
        """Distribute the samples of class `label` over the bias labels.

        Parameters
        ----------
        bias_indices : dict
            Maps each bias label to a ``LongTensor`` of sample indices.
            Updated in place.
        label : int
            The target class whose samples are being assigned.

        Raises
        ------
        ValueError
            If `n_confusing_labels` or `data_label_correlation` is out of range.
        """
        # Re-checked here because the attributes may have been changed after
        # construction.
        self._check_bias_parameters(self.data_label_correlation, self.n_confusing_labels)

        indices = np.where((self.targets == label).numpy())[0]
        self._shuffle(indices)
        indices = torch.LongTensor(indices)

        n_samples = len(indices)
        n_correlated_samples = int(n_samples * self.data_label_correlation)
        # Ceil so the remaining samples fit in at most `n_confusing_labels` chunks.
        n_decorrelated_per_class = int(np.ceil((n_samples - n_correlated_samples) / (self.n_confusing_labels)))

        # The first part keeps the colour that matches its own label.
        correlated_indices = indices[:n_correlated_samples]
        bias_indices[label] = torch.cat([bias_indices[label], correlated_indices])

        # The remainder is split into chunks, one per confusing colour.
        decorrelated_indices = torch.split(indices[n_correlated_samples:], n_decorrelated_per_class)

        # Candidate colours are the next `n_confusing_labels` labels, wrapping
        # around modulo 10; the chunk-to-colour assignment is randomised.
        other_labels = [_label % 10 for _label in range(label + 1, label + 1 + self.n_confusing_labels)]
        self._shuffle(other_labels)

        # There are never more chunks than candidate colours, so `zip` pairs
        # every chunk with a colour.
        for _label, _indices in zip(other_labels, decorrelated_indices):
            bias_indices[_label] = torch.cat([bias_indices[_label], _indices])

    def build_biased_mnist(self):
        """Build biased MNIST.

        Returns
        -------
        tuple
            ``(data, targets, biased_targets)``: the biased images and the
            target labels produced by `_make_biased_mnist`, plus a
            ``LongTensor`` with the bias label of every sample. All three are
            ordered by bias label.
        """
        n_labels = self.targets.max().item() + 1
        bias_indices = {label: torch.LongTensor() for label in range(n_labels)}
        for label in range(n_labels):
            self._update_bias_indices(bias_indices, label)

        # Collect the per-bias chunks and concatenate once at the end, instead
        # of growing a tensor (seeded with a 1-D empty tensor) inside the loop.
        data_chunks = []
        target_chunks = []
        biased_targets = []
        for bias_label, indices in bias_indices.items():
            _data, _targets = self._make_biased_mnist(indices, bias_label)
            data_chunks.append(_data)
            target_chunks.append(_targets)
            biased_targets.extend([bias_label] * len(indices))

        data = torch.cat(data_chunks)
        targets = torch.cat(target_chunks)
        biased_targets = torch.LongTensor(biased_targets)
        return data, targets, biased_targets

    def __getitem__(self, index):
        """Return ``(img, target, biased_target)`` for the sample at `index`.

        `img` is a PIL RGB image passed through `self.transform` when one is
        set; `target` is passed through `self.target_transform` when one is
        set; `biased_target` is the bias label as a plain ``int``.
        """
        img, target = self.data[index], int(self.targets[index])
        img = Image.fromarray(img.astype(np.uint8), mode='RGB')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, int(self.biased_targets[index])
class ColourBiasedMNIST(BiasedMNIST):
    """Biased-MNIST whose bias is a colour painted into the digit background."""

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, data_label_correlation=1.0, n_confusing_labels=9):
        """Forward every argument unchanged to the ``BiasedMNIST`` constructor."""
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
            data_label_correlation=data_label_correlation,
            n_confusing_labels=n_confusing_labels,
        )

    def _binary_to_colour(self, data, colour):
        """Colourise a batch of single-channel images.

        Non-zero pixels become 255 in every channel; zero pixels take `colour`.
        The channel axis is appended as the last dimension of the result.
        """
        is_background = data == 0

        # Digit strokes: 255 wherever the source pixel is non-zero.
        strokes = torch.zeros_like(data)
        strokes[~is_background] = 255

        # 0/1 mask of the background, later scaled by the colour per channel.
        backdrop = torch.zeros_like(data)
        backdrop[is_background] = 1

        tint = torch.ByteTensor(colour)
        # Strokes and backdrop never overlap, so the sum cannot overflow.
        return strokes.unsqueeze(3) + backdrop.unsqueeze(3) * tint

    def _make_biased_mnist(self, indices, label):
        """Return the images at `indices` coloured with ``COLOUR_MAP[label]``, and their targets."""
        colour = self.COLOUR_MAP[label]
        coloured = self._binary_to_colour(self.data[indices], colour)
        return coloured, self.targets[indices]
def get_biased_mnist_dataloader(root, batch_size, data_label_correlation,
                                n_confusing_labels=9, train=True, num_workers=8):
    """Build a shuffling ``DataLoader`` over a ``ColourBiasedMNIST`` dataset.

    Images are converted with ``ToTensor`` and then normalised per channel with
    mean 0.5 and std 0.5. The dataset is always created with ``download=True``,
    and the loader always uses ``shuffle=True`` and ``pin_memory=True``.

    Parameters
    ----------
    root : str
        Dataset root passed to ``ColourBiasedMNIST``.
    batch_size : int
        Batch size of the returned loader.
    data_label_correlation : float
        Bias level passed to ``ColourBiasedMNIST``.
    n_confusing_labels : int, default=9
        Number of confusing colours passed to ``ColourBiasedMNIST``.
    train : bool, default=True
        Passed to ``ColourBiasedMNIST`` as its ``train`` argument.
    num_workers : int, default=8
        Number of loader worker processes.
    """
    to_tensor = transforms.ToTensor()
    normalise = transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                     std=(0.5, 0.5, 0.5))
    preprocessing = transforms.Compose([to_tensor, normalise])

    biased_mnist = ColourBiasedMNIST(
        root,
        train=train,
        transform=preprocessing,
        download=True,
        data_label_correlation=data_label_correlation,
        n_confusing_labels=n_confusing_labels,
    )

    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    return data.DataLoader(dataset=biased_mnist, **loader_options)