-
Notifications
You must be signed in to change notification settings - Fork 9
/
crf.py
37 lines (28 loc) · 985 Bytes
/
crf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Authors: Wouter Van Gansbeke & Simon Vandenhende
# Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
import numpy as np
import pydensecrf.densecrf as dcrf
import pydensecrf.utils as utils
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as VF
# Pairwise-potential settings consumed by rgb_dense_crf.
# Gaussian pairwise term: POS_W is passed as `compat`, POS_XY_STD as `sxy`.
POS_W = 3
POS_XY_STD = 1
# Bilateral pairwise term: Bi_W is passed as `compat`, Bi_XY_STD as `sxy`,
# and Bi_RGB_STD as `srgb`.
Bi_W = 4
Bi_XY_STD = 67
Bi_RGB_STD = 3
# NOTE(review): not referenced anywhere in the visible code. The name suggests
# a per-channel mean in BGR order -- confirm where (or whether) it is used.
BGR_MEAN = np.array([104.008, 116.669, 122.675])
def rgb_dense_crf(image, output_probs, max_iter=10):
    """Run dense-CRF inference on class probabilities, guided by an image.

    Args:
        image: Reference image handed to the bilateral pairwise term as
            ``rgbim``. It is made C-contiguous before use.
            NOTE(review): presumably an (h, w, 3) uint8 array, which is what
            pydensecrf expects -- confirm against callers.
        output_probs: Array whose first three dimensions are read as
            (c, h, w); it is converted to unary energies with
            ``unary_from_softmax``.
        max_iter: Number of inference iterations to run.

    Returns:
        A numpy array of shape (c, h, w) built from the inference result.
    """
    rgb = np.ascontiguousarray(image)

    # Read the dimensions by index so a wrong-rank input fails the same way
    # it always has (IndexError on the missing axis).
    shape = output_probs.shape
    num_labels, height, width = shape[0], shape[1], shape[2]

    unary = np.ascontiguousarray(utils.unary_from_softmax(output_probs))

    # DenseCRF2D takes (width, height, n_labels) -- width comes first.
    crf = dcrf.DenseCRF2D(width, height, num_labels)
    crf.setUnaryEnergy(unary)
    crf.addPairwiseGaussian(sxy=POS_XY_STD, compat=POS_W)
    crf.addPairwiseBilateral(
        sxy=Bi_XY_STD,
        srgb=Bi_RGB_STD,
        rgbim=rgb,
        compat=Bi_W,
    )

    marginals = crf.inference(max_iter)
    return np.array(marginals).reshape((num_labels, height, width))