-
Notifications
You must be signed in to change notification settings - Fork 319
/
pvtv2.py
512 lines (437 loc) · 20 KB
/
pvtv2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PvTv2 in Paddle
A Paddle Implementation of Pyramid Vision Transformer v2 (PVTv2) as described in:
"PVT v2: Improved Baselines with Pyramid Vision Transformer"
- Paper Link: https://arxiv.org/abs/2106.13797
"""
import copy
import paddle
import paddle.nn as nn
from droppath import DropPath
class Identity(nn.Layer):
    """No-op layer: ``forward`` hands back its input untouched.

    Used as a drop-in placeholder (e.g. in place of DropPath or the
    classifier head) so callers never need an ``if`` inside ``forward``.
    """

    def forward(self, x):
        # nothing to compute; return the very same object
        return x
class DWConv(nn.Layer):
    """3x3 depth-wise convolution applied to a token sequence.

    The tokens are folded back into a [B, C, H, W] feature map, convolved
    channel by channel (groups == dim) and flattened back to [B, H*W, C].
    This improves the local continuity of features.

    Args:
        dim: int, number of channels, default: 768
    """

    def __init__(self, dim=768):
        super().__init__()
        conv_w_attr, conv_b_attr = self._init_weights_conv()
        self.dwconv = nn.Conv2D(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=1,
            padding=1,   # keeps H and W unchanged for a 3x3 kernel, stride 1
            groups=dim,  # one filter per channel -> depth-wise
            weight_attr=conv_w_attr,
            bias_attr=conv_b_attr)

    def _init_weights_conv(self):
        """Return (weight ParamAttr, bias ParamAttr): XavierNormal(fan_in=0) / zeros."""
        w_attr = paddle.ParamAttr(initializer=nn.initializer.XavierNormal(fan_in=0))
        b_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0))
        return w_attr, b_attr

    def forward(self, x, H, W):
        """x: [B, H*W, C] tokens; H, W: token grid size. Returns [B, H*W, C]."""
        batch, _, channels = x.shape
        feat_map = x.transpose([0, 2, 1]).reshape([batch, channels, H, W])
        feat_map = self.dwconv(feat_map)
        return feat_map.flatten(2).transpose([0, 2, 1])
class OverlapPatchEmbedding(nn.Layer):
    """Overlapping Patch Embedding

    Apply Overlapping Patch Embedding on input images. The embedding is implemented
    with a Conv2D whose kernel (patch_size) is larger than its stride, so adjacent
    windows overlap. The input is zero-padded by patch_size // 2 on each side.

    Attributes:
        image_size: int, input image size, default: 224
        patch_size: int, size of patch (conv kernel size), default: 7
        stride: int, stride of the patch conv, default: 4
        in_channels: int, input image channels, default: 3
        embed_dim: int, embedding dimension, default: 768
        H, W: int, size of the output token grid for an image_size x image_size input
        num_patches: int, H * W
    """

    def __init__(self, image_size=224, patch_size=7, stride=4, in_channels=3, embed_dim=768):
        super().__init__()
        image_size = (image_size, image_size)
        patch_size = (patch_size, patch_size)
        stride_hw = tuple(stride) if isinstance(stride, (tuple, list)) else (stride, stride)
        padding = (patch_size[0] // 2, patch_size[1] // 2)
        self.image_size = image_size
        self.patch_size = patch_size
        # Conv output size: (size + 2 * pad - kernel) // stride + 1, so H/W match the
        # token grid actually produced by forward (e.g. 224 -> 56 for k=7, s=4, p=3).
        self.H = (image_size[0] + 2 * padding[0] - patch_size[0]) // stride_hw[0] + 1
        self.W = (image_size[1] + 2 * padding[1] - patch_size[1]) // stride_hw[1] + 1
        self.num_patches = self.H * self.W
        self.patch_embed = nn.Conv2D(in_channels=in_channels,
                                     out_channels=embed_dim,
                                     kernel_size=patch_size,
                                     stride=stride,
                                     padding=padding)
        w_attr_1, b_attr_1 = self._init_weights()
        self.norm = nn.LayerNorm(embed_dim, weight_attr=w_attr_1, bias_attr=b_attr_1, epsilon=1e-6)

    def _init_weights(self):
        """LayerNorm init: weight = 1, bias = 0."""
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(1.0))
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0.0))
        return weight_attr, bias_attr

    def forward(self, x):
        """Embed x ([batch, in_channels, h, w]); return (tokens, H, W).

        tokens has shape [batch, H*W, embed_dim]; H, W are the token grid size.
        """
        x = self.patch_embed(x)  # [batch, embed_dim, H, W]
        _, _, H, W = x.shape
        x = x.flatten(start_axis=2, stop_axis=-1)  # [batch, embed_dim, H*W]
        x = x.transpose([0, 2, 1])  # [batch, H*W, embed_dim]
        x = self.norm(x)
        return x, H, W
class Mlp(nn.Layer):
    """Feed-forward module of a PVTv2 block.

    Ops: fc1 -> (relu, only if linear) -> dwconv -> GELU -> dropout -> fc2 -> dropout

    Attributes:
        fc1: nn.Linear, in_features -> hidden_features
        dwconv: DWConv (3x3 depth-wise conv) over the hidden features
        fc2: nn.Linear, hidden_features -> in_features
        act: GELU
        dropout: dropout applied after the activation and after fc2
        linear: bool, if True an extra ReLU follows fc1
    """

    def __init__(self, in_features, hidden_features, dropout=0.0, linear=False):
        super(Mlp, self).__init__()
        fc1_w, fc1_b = self._init_weights()
        self.fc1 = nn.Linear(in_features, hidden_features,
                             weight_attr=fc1_w, bias_attr=fc1_b)
        self.dwconv = DWConv(hidden_features)
        fc2_w, fc2_b = self._init_weights()
        self.fc2 = nn.Linear(hidden_features, in_features,
                             weight_attr=fc2_w, bias_attr=fc2_b)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.linear = linear
        if linear:
            self.relu = nn.ReLU()

    def _init_weights(self):
        """Linear init: TruncatedNormal(std=0.02) weight, zero bias."""
        w_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
        b_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
        return w_attr, b_attr

    def forward(self, x, H, W):
        """x: tokens over an H x W grid (passed through to DWConv)."""
        hidden = self.fc1(x)
        if self.linear:
            hidden = self.relu(hidden)
        hidden = self.dropout(self.act(self.dwconv(hidden, H, W)))
        return self.dropout(self.fc2(hidden))
class Attention(nn.Layer):
    """ Spatial-reduction multi-head attention used by PVTv2

    Queries are computed from every input token. Keys and values are computed
    from a spatially reduced copy of the tokens when sr_ratio > 1 or linear is True,
    which shrinks the attention matrix. q and kv use separate linear layers;
    kv produces keys and values in one projection, which are split in forward.

    Attributes:
        dim: int, input dimension (channels)
        num_heads: int, number of heads
        q: nn.Linear, query projection (dim -> dim)
        kv: nn.Linear, joint key/value projection (dim -> 2 * dim)
        qkv_bias: bool, if True, enable learnable bias on q and kv, default: False
        qk_scale: float, override default qk scale head_dim**-0.5 if set, default: None
        attn_dropout: dropout on the attention weights
        proj_dropout: final dropout before output
        softmax: softmax op over the key axis
        linear: bool, if True, use linear spatial reduction (adaptive avg-pool to 7x7,
            1x1 conv, LayerNorm, GELU) instead of the strided-conv reduction
        sr_ratio: int, kernel size and stride of the reduction conv when linear is False;
            sr_ratio == 1 disables the reduction
    """

    def __init__(self,
                 dim,
                 num_heads,
                 qkv_bias=False,
                 qk_scale=None,
                 attention_dropout=0.,
                 dropout=0.,
                 sr_ratio=1,
                 linear=False):
        """init Attention"""
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.dim_head = dim // num_heads
        self.scale = qk_scale or self.dim_head ** -0.5

        w_attr_1, b_attr_1 = self._init_weights()
        self.q = nn.Linear(dim,
                           dim,
                           weight_attr=w_attr_1,
                           bias_attr=b_attr_1 if qkv_bias else False)
        w_attr_2, b_attr_2 = self._init_weights()
        self.kv = nn.Linear(dim,
                            dim * 2,
                            weight_attr=w_attr_2,
                            bias_attr=b_attr_2 if qkv_bias else False)
        self.attn_dropout = nn.Dropout(attention_dropout)
        w_attr_3, b_attr_3 = self._init_weights()
        self.proj = nn.Linear(dim,
                              dim,
                              weight_attr=w_attr_3,
                              bias_attr=b_attr_3)
        self.proj_dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(axis=-1)

        self.linear = linear
        self.sr_ratio = sr_ratio
        w_attr_4, b_attr_4 = self._init_weights_conv()  # init for conv
        w_attr_5, b_attr_5 = self._init_weights_layernorm()  # init for layernorm
        if not linear:
            if sr_ratio > 1:
                # strided conv shrinks the k/v grid by sr_ratio along each side
                self.sr = nn.Conv2D(dim,
                                    dim,
                                    kernel_size=sr_ratio,
                                    stride=sr_ratio,
                                    weight_attr=w_attr_4,
                                    bias_attr=b_attr_4)
                self.norm = nn.LayerNorm(dim,
                                         epsilon=1e-6,
                                         weight_attr=w_attr_5,
                                         bias_attr=b_attr_5)
        else:
            # linear reduction: k/v always come from a 7x7 pooled grid
            self.pool = nn.AdaptiveAvgPool2D(7)
            self.sr = nn.Conv2D(dim,
                                dim,
                                kernel_size=1,
                                stride=1,
                                weight_attr=w_attr_4,
                                bias_attr=b_attr_4)
            self.norm = nn.LayerNorm(dim,
                                     epsilon=1e-6,
                                     weight_attr=w_attr_5,
                                     bias_attr=b_attr_5)
            self.act = nn.GELU()

    def _init_weights(self):
        """Linear init: TruncatedNormal(std=0.02) weight, zero bias."""
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.TruncatedNormal(std=.02))
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0.0))
        return weight_attr, bias_attr

    def _init_weights_layernorm(self):
        """LayerNorm init: weight = 1, bias = 0."""
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(1.0))
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0.0))
        return weight_attr, bias_attr

    def _init_weights_conv(self):
        """Conv init: XavierNormal(fan_in=0) weight, zero bias."""
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.XavierNormal(fan_in=0))
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0))
        return weight_attr, bias_attr

    def _reduce(self, x, H, W):
        """Return the token sequence keys/values are computed from (reduced or not)."""
        B, _, C = x.shape
        if self.linear:
            x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
            x_ = self.sr(self.pool(x_)).reshape([B, C, -1]).transpose([0, 2, 1])
            return self.act(self.norm(x_))
        if self.sr_ratio > 1:
            x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
            x_ = self.sr(x_).reshape([B, C, -1]).transpose([0, 2, 1])
            return self.norm(x_)
        return x

    def forward(self, x, H, W):
        """x: [B, N, C] tokens on an H x W grid (N == H * W). Returns [B, N, dim]."""
        B, N, C = x.shape
        head_dim = C // self.num_heads
        q = self.q(x).reshape([B, N, self.num_heads, head_dim]).transpose([0, 2, 1, 3])

        # kv: [2, B, num_heads, M, head_dim], M = number of (possibly reduced) tokens
        kv = self.kv(self._reduce(x, H, W)).reshape(
            [B, -1, 2, self.num_heads, head_dim]).transpose([2, 0, 3, 1, 4])
        k, v = kv[0], kv[1]

        q = q * self.scale
        attn = paddle.matmul(q, k, transpose_y=True)
        attn = self.softmax(attn)
        attn = self.attn_dropout(attn)

        z = paddle.matmul(attn, v)  # [B, num_heads, N, head_dim]
        z = z.transpose([0, 2, 1, 3])
        new_shape = z.shape[:-2] + [self.dim]
        z = z.reshape(new_shape)
        z = self.proj(z)
        z = self.proj_dropout(z)
        return z
class PvTv2Block(nn.Layer):
    """Pyramid VisionTransformerV2 block

    Pre-norm residual block: x + drop_path(attn(norm1(x))), then
    x + drop_path(mlp(norm2(x))).

    Attributes:
        dim: int, input dimension (channels)
        num_heads: int, number of attention heads
        mlp_ratio: float, ratio of mlp hidden dim and input embedding dim, default: 4.
        qkv_bias: bool, if True, enable learnable bias to q and kv, default: False
        qk_scale: float, override default qk scale head_dim**-0.5 if set, default: None
        dropout: float, dropout for output, default: 0.
        attention_dropout: float, dropout of attention, default: 0.
        drop_path: float, drop path rate, default: 0. (Identity is used when 0)
        sr_ratio: int, spatial reduction ratio of the attention's k/v, default: 1
        linear: bool, if True, use linear spatial reduction attention and the
            ReLU variant of the MLP, default: False
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, dropout=0.,
                 attention_dropout=0., drop_path=0., sr_ratio=1, linear=False):
        super(PvTv2Block, self).__init__()
        w_attr_1, b_attr_1 = self._init_weights_layernorm()  # init for layernorm
        self.norm1 = nn.LayerNorm(dim, epsilon=1e-6, weight_attr=w_attr_1, bias_attr=b_attr_1)
        self.attn = Attention(dim,
                              num_heads=num_heads,
                              qkv_bias=qkv_bias,
                              qk_scale=qk_scale,
                              attention_dropout=attention_dropout,
                              dropout=dropout,
                              sr_ratio=sr_ratio,
                              linear=linear)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
        w_attr_2, b_attr_2 = self._init_weights_layernorm()  # init for layernorm
        self.norm2 = nn.LayerNorm(dim, epsilon=1e-6, weight_attr=w_attr_2, bias_attr=b_attr_2)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       dropout=dropout,
                       linear=linear)

    def _init_weights_layernorm(self):
        """LayerNorm init: weight = 1, bias = 0."""
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(1.0))
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0.0))
        return weight_attr, bias_attr

    def forward(self, x, H, W):
        """x: tokens on an H x W grid; H, W are forwarded to attention and MLP."""
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x
class PyramidVisionTransformerV2(nn.Layer):
    """PyramidVisionTransformerV2 class

    Stage i (1-based) is made of ``patch_embedding{i}`` (OverlapPatchEmbedding),
    ``block{i}`` (LayerList of PvTv2Block) and ``norm{i}`` (LayerNorm). Every stage
    except the last reshapes its tokens back to a feature map for the next stage.
    The final tokens are mean-pooled and fed to ``head``.

    Attributes:
        patch_size: int, stored only; stage patch sizes are fixed to 7 (stage 1) and 3 (others)
        image_size: int, size of image
        num_classes: int, num of image classes; 0 or less replaces the head with Identity
        in_channels: int, channel of input image
        num_heads: list of int, num of heads in attention module, per stage
        num_stages: int, num of stages, equal to len(depths)
        depths: list of int, num of PvTv2 blocks in each stage
        mlp_ratio: list of float, per stage, hidden dim of mlp = mlp_ratio * mlp input dim
        sr_ratio: list of int, per stage spatial reduction ratio of the attention k/v
        qkv_bias: bool, if True, set q and kv layers to have bias enabled
        qk_scale: float, scale factor for qk.
        embed_dims: list of int, output dimension of each stage's patch embedding
        dropout: float, dropout rate for linear layer
        attention_dropout: float, dropout rate for attention
        drop_path: float, max drop path rate; per-block rates rise linearly from 0
        linear: bool, if True, use linear spatial reduction attn instead of spatial reduction attn
        head: nn.Linear (embed_dims[-1] -> num_classes) or Identity, classifier op.
    """

    def __init__(self,
                 image_size=224,
                 patch_size=4,
                 embed_dims=[32, 64, 160, 256],
                 num_classes=1000,
                 in_channels=3,
                 num_heads=[1, 2, 5, 8],
                 depths=[2, 2, 2, 2],
                 mlp_ratio=[8, 8, 4, 4],
                 sr_ratio=[8, 4, 2, 1],
                 qkv_bias=True,
                 qk_scale=None,
                 dropout=0.,
                 attention_dropout=0.,
                 drop_path=0.,
                 linear=False):
        super(PyramidVisionTransformerV2, self).__init__()
        self.patch_size = patch_size
        self.image_size = image_size
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.num_heads = num_heads
        self.depths = depths
        self.num_stages = len(self.depths)
        self.mlp_ratio = mlp_ratio
        self.sr_ratio = sr_ratio
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        self.embed_dims = embed_dims
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.drop_path = drop_path
        self.linear = linear

        # stochastic depth: drop path rate grows linearly over all blocks
        depth_decay = [x.item() for x in paddle.linspace(0, self.drop_path, sum(self.depths))]
        cur = 0
        for i in range(self.num_stages):
            patch_embedding = OverlapPatchEmbedding(
                image_size=self.image_size if i == 0 else self.image_size // (2 ** (i + 1)),
                patch_size=7 if i == 0 else 3,
                stride=4 if i == 0 else 2,
                in_channels=self.in_channels if i == 0 else self.embed_dims[i - 1],
                embed_dim=self.embed_dims[i])
            # NOTE(review): deepcopy of a freshly built block looks redundant; kept so
            # parameter creation/naming is unchanged — confirm before removing.
            block = nn.LayerList([copy.deepcopy(PvTv2Block(
                dim=self.embed_dims[i],
                num_heads=self.num_heads[i],
                mlp_ratio=self.mlp_ratio[i],
                qkv_bias=self.qkv_bias,
                qk_scale=self.qk_scale,
                dropout=self.dropout,
                attention_dropout=self.attention_dropout,
                drop_path=depth_decay[cur + j],
                sr_ratio=self.sr_ratio[i],
                linear=self.linear)) for j in range(self.depths[i])])
            w_attr_1, b_attr_1 = self._init_weights_layernorm()  # init for layernorm
            norm = nn.LayerNorm(self.embed_dims[i],
                                epsilon=1e-6,
                                weight_attr=w_attr_1,
                                bias_attr=b_attr_1)
            cur += self.depths[i]

            setattr(self, f"patch_embedding{i + 1}", patch_embedding)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

        # classification head: reads the last stage's width, whatever the stage count
        w_attr_2, b_attr_2 = self._init_weights()  # init for linear
        self.head = nn.Linear(self.embed_dims[-1],
                              self.num_classes,
                              weight_attr=w_attr_2,
                              bias_attr=b_attr_2) if self.num_classes > 0 else Identity()

    def _init_weights(self):
        """Linear init: TruncatedNormal(std=0.02) weight, zero bias."""
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.TruncatedNormal(std=.02))
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0.0))
        return weight_attr, bias_attr

    def _init_weights_layernorm(self):
        """LayerNorm init: weight = 1, bias = 0."""
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(1.0))
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0.0))
        return weight_attr, bias_attr

    def freeze_patch_embedding(self):
        """Stop gradient flow into the first stage's patch embedding parameters.

        Setting an attribute such as ``requires_grad`` on a Paddle Layer has no
        effect on training; freezing is done per parameter via ``stop_gradient``.
        """
        for param in self.patch_embedding1.parameters():
            param.stop_gradient = True

    def forward_features(self, x):
        """Run all stages on image batch x; return mean-pooled last-stage tokens."""
        B = x.shape[0]
        for i in range(self.num_stages):
            patch_embedding = getattr(self, f"patch_embedding{i + 1}")
            block = getattr(self, f"block{i + 1}")
            norm = getattr(self, f"norm{i + 1}")

            x, H, W = patch_embedding(x)
            for blk in block:
                x = blk(x, H, W)
            x = norm(x)
            if i != self.num_stages - 1:
                # tokens -> [B, C, H, W] feature map for the next conv patch embedding
                x = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2])
        return x.mean(axis=1)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
def build_pvtv2(config):
    """Create a PyramidVisionTransformerV2 from a config object.

    Image size and channel count come from ``config.DATA``; all architecture
    and regularization settings come from ``config.MODEL``.
    """
    data_cfg = config.DATA
    model_cfg = config.MODEL
    return PyramidVisionTransformerV2(
        image_size=data_cfg.IMAGE_SIZE,
        patch_size=model_cfg.PATCH_SIZE,
        embed_dims=model_cfg.EMBED_DIM,
        num_classes=model_cfg.NUM_CLASSES,
        in_channels=data_cfg.IMAGE_CHANNELS,
        num_heads=model_cfg.NUM_HEADS,
        depths=model_cfg.STAGE_DEPTH,
        mlp_ratio=model_cfg.MLP_RATIO,
        sr_ratio=model_cfg.SR_RATIO,
        qkv_bias=model_cfg.QKV_BIAS,
        qk_scale=model_cfg.QK_SCALE,
        dropout=model_cfg.DROPOUT,
        attention_dropout=model_cfg.ATTENTION_DROPOUT,
        drop_path=model_cfg.DROPPATH,
        linear=model_cfg.LINEAR,
    )