-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata_shifts.py
178 lines (135 loc) · 4.84 KB
/
data_shifts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import torch
import numpy as np
from typing import Tuple
from torchvision.transforms import GaussianBlur
from keras.preprocessing.image import ImageDataGenerator
# Registry of data-shift functions, keyed by the name they were registered under.
registered_shifts = {}


def register_shift(shift_name: str):
    """Build a decorator that files a function in `registered_shifts`.

    Args:
        shift_name: Key under which the decorated function is stored. An
            existing entry with the same key is replaced.

    Returns:
        A decorator that stores the function it receives and hands it back
        unchanged, so the function stays directly callable.
    """

    def _add_to_registry(shift_fn):
        registered_shifts[shift_name] = shift_fn
        return shift_fn

    return _add_to_registry
@register_shift("gaussian_noise")
def gaussian_noise(x: np.ndarray, y: np.ndarray, sigma: float):
    """Corrupt `x` with additive zero-mean Gaussian noise.

    Args:
        x: Input array. The result is clipped to [0, 1], so values are
            presumably expected to be scaled to that range already.
        y: Labels; returned as passed in.
        sigma: Standard deviation of the noise.

    Returns:
        Tuple `(x_noise, y)` where `x_noise` is a new array with the shape
        of `x`, clipped to [0, 1].
    """
    # One independent N(0, sigma) draw per element of x.
    perturbation = np.random.normal(0.0, sigma, x.shape)
    return np.clip(x + perturbation, 0.0, 1.0), y
@register_shift("salt_pepper_noise")
def salt_pepper_noise(x: np.ndarray, y: np.ndarray, noise_ratio: float):
    """Replace a random fraction of the entries of `x` with 0.0 or 1.0.

    Each entry is independently set to 0.0 ("pepper") with probability
    `noise_ratio / 2`, set to 1.0 ("salt") with probability
    `noise_ratio / 2`, and left alone otherwise.

    Args:
        x: Input array.
        y: Labels; returned as passed in.
        noise_ratio: Total probability that an entry is overwritten.

    Returns:
        Tuple `(x_noise, y)` where `x_noise` is a new array with the shape
        of `x`.
    """
    half_ratio = noise_ratio / 2
    # One draw per entry: -1 marks pepper, 1 marks salt, 0 keeps the value.
    marker = np.random.choice(
        [-1, 0, 1], size=x.shape, p=[half_ratio, 1 - noise_ratio, half_ratio]
    )
    peppered = np.where(marker == -1, 0.0, x)
    return np.where(marker == 1, 1.0, peppered), y
@register_shift("gaussian_blur")
def gaussian_blur(
    x: np.ndarray, y: np.ndarray, sigma: float, kernel_size: Tuple[int, int] = (3, 3)
):
    """Blur every image in `x` with torchvision's `GaussianBlur`.

    Inspired by the MAGDiff code:
    https://github.com/hensel-f/MAGDiff_experiments/blob/main/utils/utils_shift_functions.py

    Args:
        x: Batch of images as a 3-D or 4-D array. A 3-D batch gets a
            singleton axis inserted at position 1 (presumably the channel
            axis -- confirm against callers).
        y: Labels; returned as passed in.
        sigma: Passed to `GaussianBlur` as its `sigma`.
        kernel_size: Passed to `GaussianBlur` as its `kernel_size`.

    Returns:
        Tuple `(x_blur, y)` where `x_blur` is a 4-D float32 array built by
        stacking the per-image results.
    """
    batch = torch.from_numpy(x).float()
    if batch.ndim == 3:
        batch = batch.unsqueeze(1)
    assert batch.ndim == 4

    blur = GaussianBlur(kernel_size=kernel_size, sigma=sigma)
    # The transform is invoked once per image rather than on the whole batch.
    blurred = [blur(image).numpy() for image in batch]
    return np.stack(blurred), y
@register_shift("image_transform")
def image_transform(x: np.ndarray, y: np.ndarray, **transform_parameters):
    """Run a batch through a Keras `ImageDataGenerator`.

    Inspired by the MAGDiff code:
    https://github.com/hensel-f/MAGDiff_experiments/blob/main/utils/utils_shift_functions.py

    Example for `transform_parameters`:
        transform_parameters={
            "rotation_range": 100 * sigma,
            "zoom_range": sigma,
            "shear_range": sigma,
            "vertical_flip": sigma > 0.5,
            "width_shift_range": sigma / 2.,
            "height_shift_range": sigma / 2.,
        }

    Args:
        x: Batch of images as a 3-D or 4-D array. A 3-D batch gets a
            singleton axis inserted at position 1.
        y: Labels; returned as passed in.
        **transform_parameters: Keyword arguments for `ImageDataGenerator`.
            `fill_mode` and `data_format` are always overridden with
            "nearest" and "channels_first".

    Returns:
        Tuple `(x_transformed, y)` where `x_transformed` is the first (and,
        since the batch size equals `len(x)`, only) batch produced by the
        generator's `flow`.
    """
    batch = torch.from_numpy(x).float()
    if batch.ndim == 3:
        batch = batch.unsqueeze(1)
    assert batch.ndim == 4

    # The fixed settings come last so they win over caller-supplied values.
    generator_kwargs = {
        **transform_parameters,
        "fill_mode": "nearest",
        "data_format": "channels_first",
    }
    generator = ImageDataGenerator(**generator_kwargs)

    # shuffle=False keeps the images in their original order.
    batches = generator.flow(batch, batch_size=len(batch), shuffle=False)
    return next(batches), y
@register_shift("uniform_noise")
def uniform_noise(x: np.ndarray, y: np.ndarray, sigma: float):
    """Corrupt `x` with additive noise drawn uniformly from [-sigma, sigma).

    Args:
        x: Input array. The result is clipped to [0, 1], so values are
            presumably expected to be scaled to that range already.
        y: Labels; returned as passed in.
        sigma: Half-width of the noise interval.

    Returns:
        Tuple `(x_noise, y)` where `x_noise` is a new array with the shape
        of `x`, clipped to [0, 1].
    """
    # One independent draw per element of x.
    perturbation = np.random.uniform(-sigma, sigma, x.shape)
    return np.clip(x + perturbation, 0.0, 1.0), y
@register_shift("pixel_shuffle")
def pixel_shuffle(x: np.ndarray, y: np.ndarray, kernel_size: int):
    """Locally scramble pixels with a sliding square window.

    A `kernel_size` x `kernel_size` window slides over each image with
    stride 1. At every position the pixels of the *input* image inside the
    window are randomly permuted (one permutation shared by all channels)
    and written to the output, so where windows overlap the later window
    wins.

    Args:
        x: Batch of images, `(N, C, H, W)` or `(N, H, W)`. A 3-D batch gets
            a singleton channel axis inserted at position 1.
        y: Labels; returned as passed in.
        kernel_size: Side length of the window. If it exceeds the image
            height or width no window fits and the images are returned
            as an unmodified copy.

    Returns:
        Tuple `(x_shuffle, y)` where `x_shuffle` is a new 4-D array with the
        dtype of `x`; the input is never modified.
    """
    # Check dimensions
    if x.ndim == 3:
        x = x[:, np.newaxis, :, :]
    assert x.ndim == 4

    _, channels, height, width = x.shape

    # The last valid window starts at `size - kernel_size`, hence the `+ 1`.
    # Without it the final row and column are never shuffled, and a window
    # as large as the image is skipped altogether.
    row_starts = range(height - kernel_size + 1)
    col_starts = range(width - kernel_size + 1)

    # Writing into a copy (instead of stacking per-image results) keeps the
    # input intact and also works for an empty batch.
    x_shuffle = x.copy()
    for img, img_shuffle in zip(x, x_shuffle):
        for row in row_starts:
            row_end = row + kernel_size
            for col in col_starts:
                col_end = col + kernel_size
                # Always read from the untouched input image.
                patch = img[:, row:row_end, col:col_end]
                patch_height, patch_width = patch.shape[1:]
                permutation = np.random.permutation(patch_height * patch_width)
                # Flatten the spatial axes so one permutation reorders the
                # same pixel positions in every channel.
                shuffled = patch.reshape(channels, -1)[:, permutation]
                img_shuffle[:, row:row_end, col:col_end] = shuffled.reshape(
                    channels, patch_height, patch_width
                )
    return x_shuffle, y
@register_shift("pixel_dropout")
def pixel_dropout(x: np.ndarray, y: np.ndarray, p: float):
    """Zero out randomly chosen pixel positions in every image.

    Each `(image, row, column)` position is dropped independently with
    probability `p`; a dropped position is set to zero in all channels.

    Args:
        x: Batch of images, `(N, C, H, W)` or `(N, H, W)`. A 3-D batch gets
            a singleton channel axis inserted at position 1.
        y: Labels; returned as passed in.
        p: Drop probability, asserted to lie in [0, 1].

    Returns:
        Tuple `(x_dropout, y)` where `x_dropout` is a new 4-D array with the
        dtype of `x`; the input is never modified.
    """
    batch = x[:, np.newaxis, :, :] if x.ndim == 3 else x
    assert batch.ndim == 4
    assert 0.0 <= p <= 1.0

    num_images, _, height, width = batch.shape
    # One Bernoulli(p) draw per spatial position; True means "drop".
    drop = np.random.binomial(n=1, p=p, size=(num_images, height, width)).astype(bool)
    # Share each decision across the channel axis without materialising copies.
    drop = np.broadcast_to(drop[:, np.newaxis, :, :], batch.shape)

    x_dropout = batch.copy()
    x_dropout[drop] = 0.0
    return x_dropout, y