-
Notifications
You must be signed in to change notification settings - Fork 1
/
yolo_asff.py
120 lines (100 loc) · 4.8 KB
/
yolo_asff.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
def autopad(k, p=None):  # kernel, padding
    """Return the padding that keeps spatial size unchanged ('same' padding).

    If *p* is given explicitly it is returned untouched; otherwise it is
    derived as k // 2 for an int kernel, or element-wise for a sequence.
    """
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [x // 2 for x in k]
class Conv(nn.Module):
    # Standard convolution block: Conv2d -> BatchNorm2d -> activation.
    # ch_in, ch_out, kernel, stride, padding, groups
    def __init__(self, in_channels, out_channels, kernel=1, stride=1, padding=None, groups=1, act=True):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel,
            stride,
            autopad(kernel, padding),
            groups=groups,
            bias=False,  # bias is redundant in front of BatchNorm
        )
        self.bn = nn.BatchNorm2d(out_channels)
        # act=True -> default SiLU; an nn.Module -> used as-is; else no-op.
        if act is True:
            self.act = nn.SiLU()
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Conv, then batch norm, then activation."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Inference path for a conv with BN already folded into its weights."""
        return self.act(self.conv(x))
class ASFF(nn.Module):
    """Adaptively Spatial Feature Fusion for a three-level feature pyramid.

    Resizes the two non-target pyramid levels to the resolution/width of the
    level selected by ``level``, predicts a per-pixel 3-way softmax over the
    levels, and blends them before a final 3x3 ``expand`` conv.
    """

    def __init__(self, level, multiplier=1, rfb=False, vis=False, act_cfg=True):
        """
        Args:
            level: target pyramid level, 0 (coarsest) to 2 (finest).
            multiplier: channel-width factor; should be 1 or 0.5,
                which means the channels of ASFF can be
                512, 256, 128 -> multiplier=0.5
                1024, 512, 256 -> multiplier=1
                For even smaller, you need change code manually.
            rfb: when adding rfb, use half the weight-branch channels to
                save memory.
            vis: if True, forward() additionally returns the fusion weights
                and the summed fused map for visualization.
            act_cfg: accepted for interface compatibility; not used here.

        Raises:
            ValueError: if ``level`` is not 0, 1 or 2.
        """
        super(ASFF, self).__init__()
        # Fail fast: an invalid level would otherwise surface later as an
        # obscure AttributeError (or silently index self.dim negatively).
        if level not in (0, 1, 2):
            raise ValueError(f"level must be 0, 1 or 2, got {level!r}")
        self.level = level
        # Channels per level, coarsest first (matches x[2], x[1], x[0]).
        self.dim = [int(1024 * multiplier), int(512 * multiplier),
                    int(256 * multiplier)]
        self.inter_dim = self.dim[self.level]
        if level == 0:
            self.stride_level_1 = Conv(int(512 * multiplier), self.inter_dim, 3, 2)
            self.stride_level_2 = Conv(int(256 * multiplier), self.inter_dim, 3, 2)
            self.expand = Conv(self.inter_dim, int(1024 * multiplier), 3, 1)
        elif level == 1:
            self.compress_level_0 = Conv(int(1024 * multiplier), self.inter_dim, 1, 1)
            self.stride_level_2 = Conv(int(256 * multiplier), self.inter_dim, 3, 2)
            self.expand = Conv(self.inter_dim, int(512 * multiplier), 3, 1)
        else:  # level == 2
            self.compress_level_0 = Conv(int(1024 * multiplier), self.inter_dim, 1, 1)
            self.compress_level_1 = Conv(int(512 * multiplier), self.inter_dim, 1, 1)
            self.expand = Conv(self.inter_dim, int(256 * multiplier), 3, 1)
        # when adding rfb, we use half number of channels to save memory
        compress_c = 8 if rfb else 16
        self.weight_level_0 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_1 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = Conv(self.inter_dim, compress_c, 1, 1)
        self.weight_levels = Conv(compress_c * 3, 3, 1, 1)
        self.vis = vis

    def forward(self, x):  # l,m,s
        """Fuse three pyramid features into the configured target level.

        Args:
            x: sequence of three feature maps ordered small -> large channel
                count, i.e. (256, 512, 1024) * multiplier; x[2] is the
                coarsest map, x[0] the finest.

        Returns:
            The fused feature map at the target level's resolution, or
            ``(out, levels_weight, fused_out_reduced.sum(dim=1))`` when
            ``vis`` is set.
        """
        # FIX: the original declared level_*_resized as module-level globals,
        # which is not thread-safe and could silently reuse tensors from a
        # previous call; plain locals are correct since __init__ guarantees
        # self.level is one of 0/1/2.
        x_level_0 = x[2]  # max (coarsest, widest) feature
        x_level_1 = x[1]  # mid feature
        x_level_2 = x[0]  # min (finest) feature
        if self.level == 0:
            level_0_resized = x_level_0
            level_1_resized = self.stride_level_1(x_level_1)
            # Two-step /4 downsample: max-pool halves, strided conv halves again.
            level_2_downsampled_inter = F.max_pool2d(x_level_2, 3, stride=2, padding=1)
            level_2_resized = self.stride_level_2(level_2_downsampled_inter)
        elif self.level == 1:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(level_0_compressed, scale_factor=2, mode='nearest')
            level_1_resized = x_level_1
            level_2_resized = self.stride_level_2(x_level_2)
        else:  # self.level == 2
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(level_0_compressed, scale_factor=4, mode='nearest')
            x_level_1_compressed = self.compress_level_1(x_level_1)
            level_1_resized = F.interpolate(x_level_1_compressed, scale_factor=2, mode='nearest')
            level_2_resized = x_level_2
        # Per-pixel softmax weights over the three (resized) levels.
        level_0_weight_v = self.weight_level_0(level_0_resized)
        level_1_weight_v = self.weight_level_1(level_1_resized)
        level_2_weight_v = self.weight_level_2(level_2_resized)
        levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)
        fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] + \
                            level_1_resized * levels_weight[:, 1:2, :, :] + \
                            level_2_resized * levels_weight[:, 2:, :, :]
        out = self.expand(fused_out_reduced)
        if self.vis:
            return out, levels_weight, fused_out_reduced.sum(dim=1)
        else:
            return out