Fix: Update UnetTR architecture based on https://github.com/Project-M…
black0017 committed Jul 22, 2021
1 parent b39cde2 commit 05b3d79
Showing 4 changed files with 71 additions and 50 deletions.
58 changes: 28 additions & 30 deletions self_attention_cv/UnetTr/UnetTr.py
@@ -8,7 +8,7 @@


 class TransformerEncoder(nn.Module):
-    def __init__(self, embed_dim, num_heads, num_layers, dropout, extract_layers):
+    def __init__(self, embed_dim, num_heads, num_layers, dropout, extract_layers, dim_linear_block):
         super().__init__()
         self.layer = nn.ModuleList()
         self.extract_layers = extract_layers
@@ -17,7 +17,7 @@ def __init__(self, embed_dim, num_heads, num_layers, dropout, extract_layers):
         self.block_list = nn.ModuleList()
         for _ in range(num_layers):
             self.block_list.append(TransformerBlock(dim=embed_dim, heads=num_heads,
-                                                    dim_linear_block=1024, dropout=dropout, prenorm=True))
+                                                    dim_linear_block=dim_linear_block, dropout=dropout, prenorm=True))

     def forward(self, x):
         extract_layers = []
@@ -28,13 +28,17 @@ def forward(self, x):

         return extract_layers


 # based on https://arxiv.org/abs/2103.10504
+# implementation is influenced by practical details missing in the paper that can be found
+# https://github.com/Project-MONAI/MONAI/blob/027947bf91ff0dfac94f472ed1855cd49e3feb8d/monai/networks/nets/unetr.py
 class UNETR(nn.Module):
     def __init__(self, img_shape=(128, 128, 128), input_dim=4, output_dim=3,
-                 embed_dim=768, patch_size=16, num_heads=12, dropout=0.1,
-                 num_layers=12, ext_layers=[3, 6, 9, 12], version='light'):
+                 embed_dim=768, patch_size=16, num_heads=12, dropout=0.0,
+                 ext_layers=[3, 6, 9, 12], norm='instance',
+                 base_filters=16,
+                 dim_linear_block=3072):
         """
         Args:
             img_shape: volume shape, provided as a tuple
             input_dim: input modalities/channels
@@ -43,63 +47,57 @@ def __init__(self, img_shape=(128, 128, 128), input_dim=4, output_dim=3,
             patch_size: the non-overlapping patches to be created
             num_heads: for the transformer encoder
             dropout: percentage for dropout
+            num_layers: static to the architecture. cannot be changed with the current architecture.
             ext_layers: transformer layers to use their output
-            version: 'light' saves some parameters in the decoding part
+            norm: batch or instance norm for the conv blocks
         """
         super().__init__()
+        self.num_layers = 12
         self.input_dim = input_dim
         self.output_dim = output_dim
         self.embed_dim = embed_dim
         self.img_shape = img_shape
         self.patch_size = patch_size
         self.num_heads = num_heads
         self.dropout = dropout
-        self.num_layers = num_layers
         self.ext_layers = ext_layers
         self.patch_dim = [int(x / patch_size) for x in img_shape]
-        self.base_filters = 64
-        self.prelast_filters = 32

-        # cheap way to reduce the number of parameters in the decoding part.
-        self.yellow_conv_channels = [256, 128, 64] if version == 'light' else [512, 256, 128]
+        self.norm = nn.BatchNorm3d if norm == 'batch' else nn.InstanceNorm3d

         self.embed = Embeddings3D(input_dim=input_dim, embed_dim=embed_dim,
                                   cube_size=img_shape, patch_size=patch_size, dropout=dropout)

-        self.transformer = TransformerEncoder(embed_dim, num_heads, num_layers, dropout, ext_layers)
+        self.transformer = TransformerEncoder(embed_dim, num_heads,
+                                              self.num_layers, dropout, ext_layers,
+                                              dim_linear_block=dim_linear_block)

-        self.init_conv = Conv3DBlock(input_dim, self.base_filters, double=True)
+        self.init_conv = Conv3DBlock(input_dim, base_filters, double=True, norm=self.norm)

         # blue blocks in Fig.1
-        self.z3_blue_conv = nn.Sequential(
-            BlueBlock(in_planes=embed_dim, out_planes=512),
-            BlueBlock(in_planes=512, out_planes=256),
-            BlueBlock(in_planes=256, out_planes=128))
+        self.z3_blue_conv = BlueBlock(in_planes=embed_dim, out_planes=base_filters * 2, layers=3)

-        self.z6_blue_conv = nn.Sequential(
-            BlueBlock(in_planes=embed_dim, out_planes=512),
-            BlueBlock(in_planes=512, out_planes=256))
+        self.z6_blue_conv = BlueBlock(in_planes=embed_dim, out_planes=base_filters * 4, layers=2)

-        self.z9_blue_conv = BlueBlock(in_planes=embed_dim, out_planes=512)
+        self.z9_blue_conv = BlueBlock(in_planes=embed_dim, out_planes=base_filters * 8, layers=1)

         # Green blocks in Fig.1
-        self.z12_deconv = TranspConv3DBlock(embed_dim, 512)
+        self.z12_deconv = TranspConv3DBlock(embed_dim, base_filters * 8)

-        self.z9_deconv = TranspConv3DBlock(self.yellow_conv_channels[0], 256)
-        self.z6_deconv = TranspConv3DBlock(self.yellow_conv_channels[1], 128)
-        self.z3_deconv = TranspConv3DBlock(self.yellow_conv_channels[2], 64)
+        self.z9_deconv = TranspConv3DBlock(base_filters * 8, base_filters * 4)
+        self.z6_deconv = TranspConv3DBlock(base_filters * 4, base_filters * 2)
+        self.z3_deconv = TranspConv3DBlock(base_filters * 2, base_filters)

         # Yellow blocks in Fig.1
-        self.z9_conv = Conv3DBlock(1024, self.yellow_conv_channels[0], double=True)
-        self.z6_conv = Conv3DBlock(512, self.yellow_conv_channels[1], double=True)
-        self.z3_conv = Conv3DBlock(256, self.yellow_conv_channels[2], double=True)
+        self.z9_conv = Conv3DBlock(base_filters * 8 * 2, base_filters * 8, double=True)
+        self.z6_conv = Conv3DBlock(base_filters * 4 * 2, base_filters * 4, double=True)
+        self.z3_conv = Conv3DBlock(base_filters * 2 * 2, base_filters * 2, double=True)
         # out convolutions
         self.out_conv = nn.Sequential(
             # last yellow conv block
-            Conv3DBlock(128, self.prelast_filters, double=True),
+            Conv3DBlock(base_filters * 2, base_filters, double=True),
             # grey block, final classification layer
-            Conv3DBlock(self.prelast_filters, output_dim, kernel_size=1, double=False))
+            Conv3DBlock(base_filters, output_dim, kernel_size=1, double=False))

     def forward(self, x):
         transf_input = self.embed(x)
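For reference, a minimal usage sketch of the constructor after this change. The import path is assumed from the file location above (the package may also re-export UNETR at the top level), and the keyword values simply restate the new defaults:

```python
import torch
from self_attention_cv.UnetTr.UnetTr import UNETR  # assumed import path, per the file above

# dropout now defaults to 0.0, norm to 'instance'; num_layers is fixed at 12 internally
model = UNETR(img_shape=(64, 64, 64), input_dim=1, output_dim=1,
              embed_dim=768, patch_size=16, num_heads=12,
              norm='instance', base_filters=16, dim_linear_block=3072)

x = torch.rand(1, 1, 64, 64, 64)  # (batch, channels, D, H, W); img_shape must be divisible by patch_size
out = model(x)                    # expected shape (1, 1, 64, 64, 64), as asserted in tests/test_unetTR.py
```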
57 changes: 40 additions & 17 deletions self_attention_cv/UnetTr/modules.py
@@ -3,28 +3,42 @@

 # yellow block in Fig.1
 class Conv3DBlock(nn.Module):
-    def __init__(self, in_planes, out_planes, kernel_size=3, double=True):
+    def __init__(self, in_planes, out_planes, kernel_size=3, double=True, norm=nn.BatchNorm3d, skip=True):
         super().__init__()
+        self.skip = skip
+        self.downsample = in_planes != out_planes
+        self.final_activation = nn.LeakyReLU(negative_slope=0.01,inplace=True)
+        padding = (kernel_size - 1) // 2
         if double:
             self.conv_block = nn.Sequential(
                 nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
-                          padding=((kernel_size - 1) // 2)),
-                nn.BatchNorm3d(out_planes),
-                nn.ReLU(inplace=True),
+                          padding=padding),
+                norm(out_planes),
+                nn.LeakyReLU(negative_slope=0.01,inplace=True),
                 nn.Conv3d(out_planes, out_planes, kernel_size=kernel_size, stride=1,
-                          padding=((kernel_size - 1) // 2)),
-                nn.BatchNorm3d(out_planes),
-                nn.ReLU(inplace=True)
-            )
+                          padding=padding),
+                norm(out_planes))
         else:
             self.conv_block = nn.Sequential(
                 nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
-                          padding=((kernel_size - 1) // 2)),
-                nn.BatchNorm3d(out_planes),
-                nn.ReLU(inplace=True))
+                          padding=padding),
+                norm(out_planes))

+        if self.skip and self.downsample:
+            self.conv_down = nn.Sequential(
+                nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=1,
+                          padding=0),
+                norm(out_planes))

     def forward(self, x):
-        return self.conv_block(x)
+        y = self.conv_block(x)
+        if self.skip:
+            res = x
+            if self.downsample:
+                res = self.conv_down(res)
+            y = y + res
+        return self.final_activation(y)


 # green block in Fig.1
 class TranspConv3DBlock(nn.Module):
@@ -33,13 +47,22 @@ def __init__(self, in_planes, out_planes):
         self.block = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0)

     def forward(self, x):
-        return self.block(x)
+        y = self.block(x)
+        return y


 # blue box in Fig.1
 class BlueBlock(nn.Module):
-    def __init__(self, in_planes, out_planes):
+    def __init__(self, in_planes, out_planes, layers=1):
         super().__init__()
-        self.block = nn.Sequential(TranspConv3DBlock(in_planes, out_planes),
-                                   Conv3DBlock(out_planes, out_planes,double=False))
+        self.blocks = nn.ModuleList([TranspConv3DBlock(in_planes, out_planes),
+                                     Conv3DBlock(out_planes, out_planes, double=False)])
+        if int(layers)>=2:
+            for _ in range(int(layers) - 1):
+                self.blocks.append(TranspConv3DBlock(out_planes, out_planes))
+                self.blocks.append(Conv3DBlock(out_planes, out_planes, double=False))

     def forward(self, x):
-        return self.block(x)
+        for blk in self.blocks:
+            x = blk(x)
+        return x
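A small shape check for the reworked blocks, assuming both classes are importable from self_attention_cv.UnetTr.modules (the file above): Conv3DBlock now adds a residual path with a 1x1x1 projection whenever the channel count changes, and each BlueBlock "layer" pairs a stride-2 transposed convolution with a conv block, so the spatial size doubles per layer.

```python
import torch
from self_attention_cv.UnetTr.modules import Conv3DBlock, BlueBlock  # assumed import path

# residual conv block: channels change 32 -> 64, spatial size is preserved
block = Conv3DBlock(in_planes=32, out_planes=64, double=True)
print(block(torch.rand(1, 32, 8, 8, 8)).shape)   # torch.Size([1, 64, 8, 8, 8])

# layers=3 stacks three (transposed conv + conv) pairs: 4 -> 8 -> 16 -> 32 voxels per axis
blue = BlueBlock(in_planes=768, out_planes=32, layers=3)
print(blue(torch.rand(1, 768, 4, 4, 4)).shape)   # torch.Size([1, 32, 32, 32, 32])
```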
2 changes: 1 addition & 1 deletion self_attention_cv/version.py
@@ -1,6 +1,6 @@
 import sys

-__version__ = "1.2.0"
+__version__ = "1.2.1"

 msg = "Self_attention_cv is only compatible with Python 3.0 and newer."

4 changes: 2 additions & 2 deletions tests/test_unetTR.py
@@ -4,11 +4,11 @@

 def test_unettr():
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    model = UNETR(img_shape=(64, 64, 64), input_dim=1, output_dim=1, version='a').to(device)
+    model = UNETR(img_shape=(64, 64, 64), input_dim=1, output_dim=1).to(device)
     a = torch.rand(1, 1, 64, 64, 64).to(device)
     assert model(a).shape == (1,1,64,64,64)
     del model
-    model = UNETR(img_shape=(64, 64, 64), input_dim=1, output_dim=1, version='light').to(device)
+    model = UNETR(img_shape=(64, 64, 64), input_dim=1, output_dim=1).to(device)
     assert model(a).shape == (1, 1, 64, 64, 64)

 test_unettr()
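Since the test module calls test_unettr() at the bottom, running python tests/test_unetTR.py directly exercises both forward passes; pytest will also pick the function up by name, assuming pytest is installed.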
