-
Notifications
You must be signed in to change notification settings - Fork 0
/
filter_utils.py
66 lines (50 loc) · 1.97 KB
/
filter_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import numpy as np
def _get_conv_filter(filter_name):
    """Return the convolution kernel named ``filter_name``.

    Every kernel is nested four levels deep, i.e. shape ``(1, 1, k, k)``
    (presumably ``(out_channels, in_channels, kH, kW)`` for a torch conv2d
    weight -- confirm against callers). All kernels are plain nested Python
    lists of ints except ``'mean'``, which is a float64 NumPy array.

    Args:
        filter_name: One of ``'sobel1'``, ``'sobel2'``, ``'roberts1'``,
            ``'roberts2'``, ``'prewitt1'``, ``'prewitt2'``, ``'laplace'``,
            ``'highpass'``, ``'gauss'`` or ``'mean'``.

    Returns:
        A freshly built kernel on every call, so callers may mutate it.

    Raises:
        ValueError: If ``filter_name`` is not a known kernel (previously the
            function silently returned ``None``).
    """
    if filter_name == 'mean':
        # 3x3 box filter; already normalised (nine weights of 1/9).
        return np.tile(1 / 9, (1, 1, 3, 3))
    # Built per call so no two callers ever share the same mutable lists.
    kernels = {
        'sobel1': [[[[1, 0, -1], [2, 0, -2], [1, 0, -1]]]],
        'sobel2': [[[[1, 2, 1], [0, 0, 0], [-1, -2, -1]]]],
        'roberts1': [[[[1, 0], [0, -1]]]],
        'roberts2': [[[[0, 1], [-1, 0]]]],
        'prewitt1': [[[[1, 0, -1], [1, 0, -1], [1, 0, -1]]]],
        'prewitt2': [[[[1, 1, 1], [0, 0, 0], [-1, -1, -1]]]],
        'laplace': [[[[0, -1, 0], [-1, 4, -1], [0, -1, 0]]]],
        'highpass': [[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]],
        # 5x5 binomial kernel. NOT normalised: its weights sum to 256,
        # unlike 'mean' above, so callers must divide if they need unit gain.
        'gauss': [[[[1, 4, 6, 4, 1],
                    [4, 16, 24, 16, 4],
                    [6, 24, 36, 24, 6],
                    [4, 16, 24, 16, 4],
                    [1, 4, 6, 4, 1]]]],
    }
    try:
        return kernels[filter_name]
    except KeyError:
        raise ValueError(
            f"unknown conv filter {filter_name!r}; expected one of "
            f"{sorted([*kernels, 'mean'])}"
        ) from None
def _get_morph_filter(filter_name):
    """Return the 3x3 structuring element named ``filter_name``.

    Elements are nested four levels deep, i.e. shape ``(1, 1, 3, 3)``, and
    built as plain Python lists of 0/1 ints.

    Args:
        filter_name: ``'erosion'``, ``'dilation'`` or ``'full'``.

    Returns:
        A freshly built element on every call, so callers may mutate it.

    Raises:
        ValueError: If ``filter_name`` is unknown (previously the function
            silently returned ``None``).
    """
    elements = {
        # 'erosion' and 'full' are the same all-ones 3x3 element.
        'erosion': [[[[1, 1, 1], [1, 1, 1], [1, 1, 1]]]],
        # Cross shape: centre plus its four direct neighbours.
        'dilation': [[[[0, 1, 0], [1, 1, 1], [0, 1, 0]]]],
        'full': [[[[1, 1, 1], [1, 1, 1], [1, 1, 1]]]],
    }
    try:
        return elements[filter_name]
    except KeyError:
        raise ValueError(
            f"unknown morph filter {filter_name!r}; expected one of "
            f"{sorted(elements)}"
        ) from None
def _get_fft_filter(filter_name, threshold, img_shape, bandwidth=30):
    """Build a radially symmetric frequency-domain mask of shape ``img_shape``.

    Each mask value depends only on the Euclidean distance ``d`` of its
    position from ``(H // 2, W // 2)``, where ``H, W = img_shape[-2:]``
    (presumably the centre of an fftshift-ed spectrum -- confirm with
    callers). Any leading dimensions of ``img_shape`` are broadcast, so every
    ``H x W`` slice holds the same mask.

    Args:
        filter_name: ``'bandpass'``, ``'gausslowpass'`` or ``'gaussbandpass'``.
        threshold: Radius ``t`` (in pixels) used by every mask.
        img_shape: Shape of the mask to build; needs at least two dims.
        bandwidth: Width ``w`` used by ``'gaussbandpass'``. Defaults to 30,
            the value that was previously hard-coded.

    Returns:
        A new float64 ndarray of shape ``img_shape``:

        * ``'bandpass'``: 1 where ``d <= t``, else 0 (despite its name this is
          a filled disc, keeping only distances up to ``t``).
        * ``'gausslowpass'``: ``1 / (1 + (d / t) ** 2)``.
        * ``'gaussbandpass'``: ``exp(-((d**2 - t**2) / (w * d)) ** 2)``; at
          ``d == 0`` this evaluates to 0 when ``t != 0``.

    Raises:
        ValueError: If ``filter_name`` is unknown (previously returned None).
    """
    img_shape = tuple(img_shape)
    n_rows, n_cols = img_shape[-2], img_shape[-1]
    row_off = np.arange(n_rows) - n_rows // 2
    col_off = np.arange(n_cols) - n_cols // 2
    # Integer squares then sqrt -- the same arithmetic as a per-pixel loop,
    # but vectorised via an outer broadcast of row and column offsets.
    dist = np.sqrt(np.power(row_off[:, None], 2) + np.power(col_off[None, :], 2))
    # Measure distance over the LAST two axes only; leading axes just repeat.
    dist = np.broadcast_to(dist, img_shape)

    if filter_name == 'bandpass':
        # Derive both values from one comparison: writing 1s first and then
        # zeroing everything > threshold also wiped those 1s when threshold < 1.
        return np.where(dist <= threshold, 1.0, 0.0)
    if filter_name == 'gausslowpass':
        return 1 / (1 + np.power(dist / threshold, 2))
    if filter_name == 'gaussbandpass':
        # d == 0 divides by zero: numpy yields -inf -> exp(-inf) == 0 (nan if
        # threshold == 0 too). Values are unchanged; only the warning is muted.
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.exp(-np.power(
                (np.power(dist, 2) - np.power(threshold, 2)) / (bandwidth * dist),
                2,
            ))
    raise ValueError(
        f"unknown fft filter {filter_name!r}; expected one of "
        f"['bandpass', 'gaussbandpass', 'gausslowpass']"
    )
def normalize_image(img, gray_levels):
    """Linearly rescale ``img`` to integers in ``[0, gray_levels]``.

    The minimum value maps to 0 and the maximum to ``gray_levels``.
    Intermediate values are scaled linearly, then truncated toward zero by
    the integer cast. Note that the maximum lands on ``gray_levels`` itself,
    not ``gray_levels - 1`` -- NOTE(review): confirm callers expect that.

    Args:
        img: Tensor to normalise. It is left unmodified (previously it was
            shifted and scaled in place, clobbering the caller's data).
        gray_levels: Value that the maximum of ``img`` maps to.

    Returns:
        A new ``torch.int`` (int32) tensor with the same shape as ``img``.
        A constant image maps to all zeros rather than casting NaN to int.
    """
    # True division needs a floating dtype; float inputs keep their precision.
    if not torch.is_floating_point(img):
        img = img.float()
    # Subtract the minimum. Adding |min| only zeroed it when min <= 0 and
    # pushed a positive minimum up to 2 * min instead.
    shifted = img - torch.min(img)
    peak = torch.max(shifted)
    if peak == 0:
        # Constant image: avoid 0 / 0 -> NaN before the integer cast.
        return torch.zeros_like(shifted, dtype=torch.int)
    return (shifted / peak * gray_levels).type(torch.int)