-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
209 lines (141 loc) · 6.37 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import torch
import torch.nn as nn
import torch.nn.functional as F
# def get_vgg19_FeatureMap(vgg_model, input_255, layer_index):
# vgg_mean = torch.tensor([123.6800, 116.7790, 103.9390]).reshape((1,3,1,1))
# if torch.cuda.is_available():
# vgg_mean = vgg_mean.cuda()
# vgg_input = input_255-vgg_mean
# #x = vgg_model.features[0](vgg_input)
# #FeatureMap_list.append(x)
# for i in range(0,layer_index+1):
# if i == 0:
# x = vgg_model.features[0](vgg_input)
# else:
# x = vgg_model.features[i](x)
# return x
def l_num_loss(img1, img2, l_num=1):
    """Return the mean of ``|(img1 - img2) ** l_num|`` over all elements.

    With ``l_num=1`` this is the mean absolute error; with ``l_num=2`` it is
    the mean squared error.
    """
    residual = img1 - img2
    powered = residual ** l_num
    return powered.abs().mean()
def boundary_extraction(mask, iterations=7):
    """Return the part of ``mask`` lying within ``iterations`` pixels of its zero region.

    The complement ``1 - mask`` is dilated ``iterations`` times with a 3x3
    all-ones kernel. Each step convolves, then sets responses >= 1 to 1 and
    everything else to 0. The dilated band is then multiplied back by ``mask``.

    Args:
        mask: tensor fed to ``F.conv2d`` with a (1, 1, 3, 3) kernel, so it must
            be 4-D with a single channel, i.e. (N, 1, H, W).
        iterations: number of 3x3 dilation steps. The default of 7 reproduces
            the previously hard-coded behaviour.

    Returns:
        Tensor with the same shape, device and dtype as ``mask``.
    """
    # Build the kernel on the mask's own device and dtype rather than forcing
    # ``.cuda()``. The old behaviour crashed for CPU masks on CUDA machines and
    # for non-float32 masks.
    kernel = torch.ones((1, 1, 3, 3), device=mask.device, dtype=mask.dtype)
    zeros = torch.zeros_like(mask)
    ones = torch.ones_like(mask)

    x = 1 - mask
    for _ in range(iterations):
        # padding=1 keeps H and W unchanged for the 3x3 kernel.
        x = F.conv2d(x, kernel, stride=1, padding=1)
        # Binarise: any pixel touched by the current region joins it.
        x = torch.where(x < 1, zeros, ones)
    return x * mask
def cal_boundary_term(inpu1_tesnor, inpu2_tesnor, mask1_tesnor, mask2_tesnor, stitched_image):
    """L1 consistency between each input and the stitched image near the mask edges.

    For each input, ``boundary_extraction`` is applied to the *other* input's
    mask. The result is multiplied by this input's own mask to form a band.
    Inside that band the input and ``stitched_image`` are compared with
    ``l_num_loss(..., 1)``.

    Returns:
        A tuple ``(loss1 + loss2, boundary_mask1)``. The second element is the
        band used for input 1.
    """
    pairs = ((mask1_tesnor, mask2_tesnor), (mask2_tesnor, mask1_tesnor))
    bands = [own * boundary_extraction(other) for own, other in pairs]

    terms = [
        l_num_loss(source * band, stitched_image * band, 1)
        for source, band in zip((inpu1_tesnor, inpu2_tesnor), bands)
    ]
    return terms[0] + terms[1], bands[0]
def cal_smooth_term_stitch(stitched_image, learned_mask1):
    """Penalise image gradients where the learned mask changes value.

    Absolute first differences between neighbouring pixels are taken along
    dim 2 (height) and dim 3 (width) for both the mask and the stitched image.
    Their products are averaged per direction, and the two means are summed.
    """
    # |x[i] - x[i+1]| equals |x[i+1] - x[i]|, so torch.diff matches the
    # explicit slicing exactly.
    mask_step_h = torch.diff(learned_mask1, dim=2).abs()
    mask_step_w = torch.diff(learned_mask1, dim=3).abs()
    image_step_h = torch.diff(stitched_image, dim=2).abs()
    image_step_w = torch.diff(stitched_image, dim=3).abs()

    return (mask_step_h * image_step_h).mean() + (mask_step_w * image_step_w).mean()
def cal_smooth_term_diff(img1, img2, learned_mask1, overlap):
    """Cost of placing a mask transition between pixels with differing content.

    The per-pixel cost is ``|img1 - img2| ** 2`` restricted to ``overlap``.
    For each pair of neighbouring pixels along dim 2 and along dim 3, the sum
    of the two costs is weighted by the absolute change of ``learned_mask1``
    between those pixels. The per-direction means are then added.
    """
    cost = torch.abs(img1 - img2) ** 2 * overlap

    # Neighbour pairs along dim 2 (height).
    seam_h = torch.abs(torch.diff(learned_mask1, dim=2))
    pair_cost_h = torch.abs(cost[:, :, :-1, :] + cost[:, :, 1:, :])

    # Neighbour pairs along dim 3 (width).
    seam_w = torch.abs(torch.diff(learned_mask1, dim=3))
    pair_cost_w = torch.abs(cost[:, :, :, :-1] + cost[:, :, :, 1:])

    return torch.mean(seam_h * pair_cost_h) + torch.mean(seam_w * pair_cost_w)
# dh_zeros = torch.zeros_like(dh_pixel)
# dw_zeros = torch.zeros_like(dw_pixel)
# if torch.cuda.is_available():
# dh_zeros = dh_zeros.cuda()
# dw_zeros = dw_zeros.cuda()
# loss = l_num_loss(dh_pixel, dh_zeros, 1) + l_num_loss(dw_pixel, dw_zeros, 1)
# return loss, dh_pixel
"""
warp1\input1\mask1是ref
warp2\input2\mask2是target
TODO:基于salient object detection的语义loss
鼓励object的mask都来自M_cr
M_1 = M_object & M_cr
M_2 = M_object & M_ct
loss = M_1*M_cr + M_2*M_cr
return loss
"""
#================================================================
import torch
def object_completeness(object_mask1, learned_mask1, learned_mask2):
    """Encourage an object to be taken entirely from one of the two learned masks.

    The mask whose overlap with ``object_mask1`` has the larger total sum is
    "kept". On a tie, ``learned_mask2`` is kept. The loss has two parts:

    * completeness: ``sum((object_mask1 - kept_overlap) ** 2) / numel(object_mask1)``,
      which pushes the kept mask to cover the whole object;
    * exclusivity: ``sum(rejected_overlap ** 2)``, which pushes the other mask
      off the object.

    Previously the exclusivity term always penalised the ``learned_mask2``
    overlap. When ``learned_mask2`` was the kept mask, that term contradicted
    the completeness term. It now always penalises the non-kept mask.
    Results are unchanged whenever ``learned_mask1`` is kept.

    NOTE(review): the kept/rejected decision uses sums over the whole tensor,
    so a single choice is made for the entire batch rather than per sample.

    Returns:
        Scalar tensor: completeness_loss + exclusivity_loss.
    """
    # Overlap of the object with each learned mask.
    M_1 = object_mask1 * learned_mask1
    M_2 = object_mask1 * learned_mask2

    # Keep the mask that already covers more of the object.
    if torch.sum(M_1) > torch.sum(M_2):
        kept, rejected = M_1, M_2
    else:
        kept, rejected = M_2, M_1

    completeness_loss = torch.sum((object_mask1 - kept) ** 2) / torch.numel(object_mask1)
    exclusivity_loss = torch.sum(rejected ** 2)
    return completeness_loss + exclusivity_loss
# def object_completeness(object_mask1, learned_mask1, learned_mask2):
# # 计算object_mask1 和 learned_mask1 的重叠部分
# M_1 = object_mask1 * learned_mask1
# # 计算object_mask1 和 learned_mask2 的重叠部分
# M_2 = object_mask1 * learned_mask2
# # 完整性损失: 鼓励object_mask1完全在learned_mask1内
# # completeness_loss = torch.sum((object_mask1 - M_1) ** 2)
# completeness_loss = torch.sum((object_mask1 - M_1) ** 2) / torch.numel(object_mask1)
# # 排他性损失: 最小化object_mask1与learned_mask2的重叠
# exclusivity_loss = torch.sum(M_2 ** 2)
# # 总损失: 考虑到完整性和排他性
# loss = completeness_loss + exclusivity_loss
# return loss
#================================================================
"""
def object_completeness(object_mask1, learned_mask1, learned_mask2):
# 计算object_mask1 和 learned_mask1 的重叠部分
M_1 = object_mask1 * learned_mask1
# 计算object_mask1 和 learned_mask2 的重叠部分
M_2 = object_mask1 * learned_mask2
# 完整性损失: 鼓励object_mask1完全在learned_mask1内
completeness_loss = torch.sum((object_mask1 - M_1) ** 2)
# 总损失: 考虑到完整性和排他性
loss = completeness_loss
return loss
"""