diff --git a/self_attention_cv/UnetTr/UnetTr.py b/self_attention_cv/UnetTr/UnetTr.py
index 09045b5..a81a7fd 100644
--- a/self_attention_cv/UnetTr/UnetTr.py
+++ b/self_attention_cv/UnetTr/UnetTr.py
@@ -8,7 +8,7 @@ class TransformerEncoder(nn.Module):
-    def __init__(self, embed_dim, num_heads, num_layers, dropout, extract_layers):
+    def __init__(self, embed_dim, num_heads, num_layers, dropout, extract_layers, dim_linear_block):
         super().__init__()
         self.layer = nn.ModuleList()
         self.extract_layers = extract_layers
@@ -17,7 +17,7 @@ def __init__(self, embed_dim, num_heads, num_layers, dropout, extract_layers):
         self.block_list = nn.ModuleList()
         for _ in range(num_layers):
             self.block_list.append(TransformerBlock(dim=embed_dim, heads=num_heads,
-                                                    dim_linear_block=1024, dropout=dropout, prenorm=True))
+                                                    dim_linear_block=dim_linear_block, dropout=dropout, prenorm=True))
 
     def forward(self, x):
         extract_layers = []
@@ -28,13 +28,17 @@ def forward(self, x):
 
         return extract_layers
 
+
 # based on https://arxiv.org/abs/2103.10504
+# implementation is influenced by practical details missing in the paper that can be found
+# https://github.com/Project-MONAI/MONAI/blob/027947bf91ff0dfac94f472ed1855cd49e3feb8d/monai/networks/nets/unetr.py
 class UNETR(nn.Module):
     def __init__(self, img_shape=(128, 128, 128), input_dim=4, output_dim=3,
-                 embed_dim=768, patch_size=16, num_heads=12, dropout=0.1,
-                 num_layers=12, ext_layers=[3, 6, 9, 12], version='light'):
+                 embed_dim=768, patch_size=16, num_heads=12, dropout=0.0,
+                 ext_layers=[3, 6, 9, 12], norm='instance',
+                 base_filters=16,
+                 dim_linear_block=3072):
         """
-
         Args:
             img_shape: volume shape, provided as a tuple
             input_dim: input modalities/channels
@@ -43,11 +47,12 @@ def __init__(self, img_shape=(128, 128, 128), input_dim=4, output_dim=3,
             patch_size: the non-overlapping patches to be created
             num_heads: for the transformer encoder
             dropout: percentage for dropout
-            num_layers: static to the architecture. cannot be changed with the current architecture.
             ext_layers: transformer layers to use their output
             version: 'light' saves some parameters in the decoding part
+            norm: batch or instance norm for the conv blocks
         """
         super().__init__()
+        self.num_layers = 12
         self.input_dim = input_dim
         self.output_dim = output_dim
         self.embed_dim = embed_dim
@@ -55,51 +60,44 @@ def __init__(self, img_shape=(128, 128, 128), input_dim=4, output_dim=3,
         self.patch_size = patch_size
         self.num_heads = num_heads
         self.dropout = dropout
-        self.num_layers = num_layers
         self.ext_layers = ext_layers
         self.patch_dim = [int(x / patch_size) for x in img_shape]
-        self.base_filters = 64
-        self.prelast_filters = 32
-        # cheap way to reduce the number of parameters in the decoding part.
-        self.yellow_conv_channels = [256, 128, 64] if version == 'light' else [512, 256, 128]
+        self.norm = nn.BatchNorm3d if norm == 'batch' else nn.InstanceNorm3d
 
         self.embed = Embeddings3D(input_dim=input_dim, embed_dim=embed_dim,
                                   cube_size=img_shape, patch_size=patch_size, dropout=dropout)
 
-        self.transformer = TransformerEncoder(embed_dim, num_heads, num_layers, dropout, ext_layers)
+        self.transformer = TransformerEncoder(embed_dim, num_heads,
+                                              self.num_layers, dropout, ext_layers,
+                                              dim_linear_block=dim_linear_block)
 
-        self.init_conv = Conv3DBlock(input_dim, self.base_filters, double=True)
+        self.init_conv = Conv3DBlock(input_dim, base_filters, double=True, norm=self.norm)
 
         # blue blocks in Fig.1
-        self.z3_blue_conv = nn.Sequential(
-            BlueBlock(in_planes=embed_dim, out_planes=512),
-            BlueBlock(in_planes=512, out_planes=256),
-            BlueBlock(in_planes=256, out_planes=128))
+        self.z3_blue_conv = BlueBlock(in_planes=embed_dim, out_planes=base_filters * 2, layers=3)
 
-        self.z6_blue_conv = nn.Sequential(
-            BlueBlock(in_planes=embed_dim, out_planes=512),
-            BlueBlock(in_planes=512, out_planes=256))
+        self.z6_blue_conv = BlueBlock(in_planes=embed_dim, out_planes=base_filters * 4, layers=2)
 
-        self.z9_blue_conv = BlueBlock(in_planes=embed_dim, out_planes=512)
+        self.z9_blue_conv = BlueBlock(in_planes=embed_dim, out_planes=base_filters * 8, layers=1)
 
         # Green blocks in Fig.1
-        self.z12_deconv = TranspConv3DBlock(embed_dim, 512)
+        self.z12_deconv = TranspConv3DBlock(embed_dim, base_filters * 8)
 
-        self.z9_deconv = TranspConv3DBlock(self.yellow_conv_channels[0], 256)
-        self.z6_deconv = TranspConv3DBlock(self.yellow_conv_channels[1], 128)
-        self.z3_deconv = TranspConv3DBlock(self.yellow_conv_channels[2], 64)
+        self.z9_deconv = TranspConv3DBlock(base_filters * 8, base_filters * 4)
+        self.z6_deconv = TranspConv3DBlock(base_filters * 4, base_filters * 2)
+        self.z3_deconv = TranspConv3DBlock(base_filters * 2, base_filters)
 
         # Yellow blocks in Fig.1
-        self.z9_conv = Conv3DBlock(1024, self.yellow_conv_channels[0], double=True)
-        self.z6_conv = Conv3DBlock(512, self.yellow_conv_channels[1], double=True)
-        self.z3_conv = Conv3DBlock(256, self.yellow_conv_channels[2], double=True)
+        self.z9_conv = Conv3DBlock(base_filters * 8 * 2, base_filters * 8, double=True)
+        self.z6_conv = Conv3DBlock(base_filters * 4 * 2, base_filters * 4, double=True)
+        self.z3_conv = Conv3DBlock(base_filters * 2 * 2, base_filters * 2, double=True)
 
         # out convolutions
         self.out_conv = nn.Sequential(
             # last yellow conv block
-            Conv3DBlock(128, self.prelast_filters, double=True),
+            Conv3DBlock(base_filters * 2, base_filters, double=True),
             # grey block, final classification layer
-            Conv3DBlock(self.prelast_filters, output_dim, kernel_size=1, double=False))
+            Conv3DBlock(base_filters, output_dim, kernel_size=1, double=False))
 
     def forward(self, x):
         transf_input = self.embed(x)
diff --git a/self_attention_cv/UnetTr/modules.py b/self_attention_cv/UnetTr/modules.py
index 8f04c8b..4a86d28 100644
--- a/self_attention_cv/UnetTr/modules.py
+++ b/self_attention_cv/UnetTr/modules.py
@@ -3,28 +3,42 @@
 
 # yellow block in Fig.1
 class Conv3DBlock(nn.Module):
-    def __init__(self, in_planes, out_planes, kernel_size=3, double=True):
+    def __init__(self, in_planes, out_planes, kernel_size=3, double=True, norm=nn.BatchNorm3d, skip=True):
         super().__init__()
+        self.skip = skip
+        self.downsample = in_planes != out_planes
+        self.final_activation = nn.LeakyReLU(negative_slope=0.01, inplace=True)
+        padding = (kernel_size - 1) // 2
         if double:
             self.conv_block = nn.Sequential(
                 nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
-                          padding=((kernel_size - 1) // 2)),
-                nn.BatchNorm3d(out_planes),
-                nn.ReLU(inplace=True),
+                          padding=padding),
+                norm(out_planes),
+                nn.LeakyReLU(negative_slope=0.01, inplace=True),
                 nn.Conv3d(out_planes, out_planes, kernel_size=kernel_size, stride=1,
-                          padding=((kernel_size - 1) // 2)),
-                nn.BatchNorm3d(out_planes),
-                nn.ReLU(inplace=True)
-            )
+                          padding=padding),
+                norm(out_planes))
         else:
             self.conv_block = nn.Sequential(
                 nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
-                          padding=((kernel_size - 1) // 2)),
-                nn.BatchNorm3d(out_planes),
-                nn.ReLU(inplace=True))
+                          padding=padding),
+                norm(out_planes))
+
+        if self.skip and self.downsample:
+            self.conv_down = nn.Sequential(
+                nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=1,
+                          padding=0),
+                norm(out_planes))
 
     def forward(self, x):
-        return self.conv_block(x)
+        y = self.conv_block(x)
+        if self.skip:
+            res = x
+            if self.downsample:
+                res = self.conv_down(res)
+            y = y + res
+        return self.final_activation(y)
+
 
 # green block in Fig.1
 class TranspConv3DBlock(nn.Module):
@@ -33,13 +47,22 @@ def __init__(self, in_planes, out_planes):
         self.block = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0)
 
     def forward(self, x):
-        return self.block(x)
+        y = self.block(x)
+        return y
+
 
 # blue box in Fig.1
 class BlueBlock(nn.Module):
-    def __init__(self, in_planes, out_planes):
+    def __init__(self, in_planes, out_planes, layers=1):
         super().__init__()
-        self.block = nn.Sequential(TranspConv3DBlock(in_planes, out_planes),
-                                   Conv3DBlock(out_planes, out_planes, double=False))
+        self.blocks = nn.ModuleList([TranspConv3DBlock(in_planes, out_planes),
+                                     Conv3DBlock(out_planes, out_planes, double=False)])
+        if int(layers) >= 2:
+            for _ in range(int(layers) - 1):
+                self.blocks.append(TranspConv3DBlock(out_planes, out_planes))
+                self.blocks.append(Conv3DBlock(out_planes, out_planes, double=False))
+
     def forward(self, x):
-        return self.block(x)
\ No newline at end of file
+        for blk in self.blocks:
+            x = blk(x)
+        return x
diff --git a/self_attention_cv/version.py b/self_attention_cv/version.py
index 3042aca..a00d195 100644
--- a/self_attention_cv/version.py
+++ b/self_attention_cv/version.py
@@ -1,6 +1,6 @@
 import sys
 
-__version__ = "1.2.0"
+__version__ = "1.2.1"
 
 msg = "Self_attention_cv is only compatible with Python 3.0 and newer."
diff --git a/tests/test_unetTR.py b/tests/test_unetTR.py
index 00d62c6..8e99eb2 100644
--- a/tests/test_unetTR.py
+++ b/tests/test_unetTR.py
@@ -4,11 +4,11 @@
 
 def test_unettr():
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    model = UNETR(img_shape=(64, 64, 64), input_dim=1, output_dim=1, version='a').to(device)
+    model = UNETR(img_shape=(64, 64, 64), input_dim=1, output_dim=1).to(device)
    a = torch.rand(1, 1, 64, 64, 64).to(device)
     assert model(a).shape == (1, 1, 64, 64, 64)
     del model
-    model = UNETR(img_shape=(64, 64, 64), input_dim=1, output_dim=1, version='light').to(device)
+    model = UNETR(img_shape=(64, 64, 64), input_dim=1, output_dim=1).to(device)
     assert model(a).shape == (1, 1, 64, 64, 64)
 
 test_unettr()
\ No newline at end of file
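
Usage sketch (not part of the patch): a minimal example of how the revised UNETR constructor is expected to be called after this change. It only uses arguments present in the new signature above (norm, base_filters, dim_linear_block); the import path is an assumption (UNETR exported at the package top level), and the tensor shape mirrors the updated test.

import torch

from self_attention_cv import UNETR

# dropout now defaults to 0.0, num_layers is fixed to 12 internally,
# and the decoder width is set by base_filters instead of version='light';
# the keyword values below simply restate the new defaults
model = UNETR(img_shape=(64, 64, 64), input_dim=1, output_dim=1,
              norm='instance', base_filters=16, dim_linear_block=3072)

x = torch.rand(1, 1, 64, 64, 64)  # (batch, modalities/channels, D, H, W)
assert model(x).shape == (1, 1, 64, 64, 64)  # segmentation output matches the input volume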