diff --git a/deeplay/components/diffusion/attention_unet.py b/deeplay/components/diffusion/attention_unet.py index 2a27e8ff..1aa9053c 100644 --- a/deeplay/components/diffusion/attention_unet.py +++ b/deeplay/components/diffusion/attention_unet.py @@ -29,49 +29,36 @@ def __init__( ) -class DoubleConvBlock(DeeplayModule): - """ - Connects two base blocks in series either with or without a skip connection. Acts like a residual block. - """ - - def __init__(self, in_channels, out_channels, residual=False): - super().__init__() - self.blocks = LayerList() - self.blocks.append(Block(in_channels, out_channels)) - self.blocks.append(Block(out_channels, out_channels, activation=nn.Identity())) - self.residual = residual - - def forward(self, x): - x_input = x - for block in self.blocks: - x = block(x) - return F.gelu(x_input + x) if self.residual else x - - class AttentionBlock(DeeplayModule): """ - Applies attention mechanism to the input tensor. Based on the input, it can handle both self-attention and cross-attention mechanisms. If context_embedding_dim is provided, it will apply cross-attention else it will apply self-attention. + Applies attention mechanism to the input tensor. Depending on the input, it can handle both self-attention and cross-attention mechanisms. If context_embedding_dim is provided, it will apply cross-attention, else it will apply self-attention. """ - def __init__(self, channels, context_embedding_dim): + def __init__(self, channels, context_embedding_dim, num_attention_heads): super().__init__() self.channels = channels - # Self-attention part of the basic transformer + # Self-attention part of the basic transformer action self.layer_norm1 = Layer(nn.LayerNorm, [channels]) self.self_attention = Layer( - nn.MultiheadAttention, channels, num_heads=1, batch_first=True + nn.MultiheadAttention, + channels, + num_heads=num_attention_heads["self"], + batch_first=True, ) # Cross-attention if context is enabled if context_embedding_dim is not None: self.cross_attention = Layer( - nn.MultiheadAttention, channels, num_heads=1, batch_first=True + nn.MultiheadAttention, + channels, + num_heads=num_attention_heads["cross"], + batch_first=True, ) self.layer_norm2 = Layer(nn.LayerNorm, [channels]) self.context_projection = Layer(nn.Linear, context_embedding_dim, channels) - # Feedforward part of the basic transformer + # Feedforward part of the basic transformer action self.layer_norm3 = Layer(nn.LayerNorm, [channels]) self.feed_forward = Sequential( Layer(nn.Linear, channels, channels), @@ -98,10 +85,9 @@ def forward(self, x, context): x = x + cross_attention_output # Feedforward - x_input = x - x = self.layer_norm3(x) - x = self.feed_forward(x) - x = x + x_input + z = self.layer_norm3(x) + z = self.feed_forward(z) + x = x + z # Reshape back to original shape: [B, H*W, C] -> [B, C, H, W] x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) @@ -110,50 +96,54 @@ def forward(self, x, context): class FeatureIntegrationModule(DeeplayModule): """ - Integrates the features with context information (such as time, context) using positional encoding and self-attention. + Integrates the time and context information to the feature maps through residual connections and attention mechanisms. """ def __init__( self, in_channels, out_channels, - embedding_dim, + position_embedding_dim, context_embedding_dim, + num_attention_heads, enable_attention=True, ): super().__init__() - self.conv_block = ( - Sequential( - DoubleConvBlock(in_channels, in_channels, residual=True), - DoubleConvBlock(in_channels, out_channels), - ) - if enable_attention - else DoubleConvBlock(in_channels, out_channels) - ) - self.feed_forward_positional_embedding = Layer( - nn.Linear, embedding_dim, out_channels + + self.blocks = LayerList() + self.blocks.append(Block(in_channels, out_channels)) + self.blocks.append(Block(out_channels, out_channels)) + + # self.res_block = Block(in_channels, out_channels) + self.res_block = Layer(nn.Conv2d, in_channels, out_channels, kernel_size=1) + + self.feed_forward_position_embedding = Layer( + nn.Linear, position_embedding_dim, out_channels ) + self.enable_attention = enable_attention if self.enable_attention: self.attention_layer = AttentionBlock( - channels=out_channels, context_embedding_dim=context_embedding_dim + channels=out_channels, + context_embedding_dim=context_embedding_dim, + num_attention_heads=num_attention_heads, ) def forward(self, x, t, context): - x = self.conv_block(x) - # Tranform the time step embedding to the channel dimension of the input feature map - emb = self.feed_forward_positional_embedding(t) - # Repeat and reshape the embedding to match the spatial dimensions of the input feature map - emb = emb[:, :, None, None].repeat(1, 1, x.shape[2], x.shape[3]) + h = self.blocks[0](x) - # Add the positional encoding to the input feature map - x = x + emb + # Project the positional encoding and add it to the input feature map + emb = self.feed_forward_position_embedding(t) + h += emb[:, :, None, None] + h = self.blocks[1](h) - # Apply self-attention to the input feature map - if self.enable_attention: - x = self.attention_layer(x, context) + # Residual connection to the input feature map + h += self.res_block(x) - return x + # Apply self-attention if enabled + if self.enable_attention: + h = self.attention_layer(h, context) + return h class UNetEncoder(DeeplayModule): @@ -170,30 +160,28 @@ def __init__( channel_attention, position_embedding_dim, context_embedding_dim, + num_attention_heads, ): super().__init__() - self.conv_block1 = DoubleConvBlock(in_channels, channels[0]) self.blocks = LayerList() - for i in range(len(channels) - 1): - attention_flag = channel_attention[i + 1] + for i in range(len(channels)): + attention_flag = channel_attention[i] self.blocks.append( FeatureIntegrationModule( + in_channels if i == 0 else channels[i - 1], channels[i], - channels[i + 1], position_embedding_dim, + context_embedding_dim, + num_attention_heads, enable_attention=attention_flag, - context_embedding_dim=context_embedding_dim, ) ) + self.pool = Layer(nn.MaxPool2d, kernel_size=2, stride=2) def forward(self, x, t, context): feature_maps = [] - x = self.conv_block1(x) - feature_maps.append(x) - x = self.pool(x) - for block in self.blocks: x = block(x, t, context) feature_maps.append(x) @@ -209,7 +197,12 @@ class UNetDecoder(DeeplayModule): """ def __init__( - self, channels, channel_attention, position_embedding_dim, context_embedding_dim + self, + channels, + channel_attention, + position_embedding_dim, + context_embedding_dim, + num_attention_heads, ): super().__init__() @@ -234,8 +227,9 @@ def __init__( channels[i], channels[i + 1], position_embedding_dim, + context_embedding_dim, + num_attention_heads, enable_attention=attention_flag, - context_embedding_dim=context_embedding_dim, ), ) ) @@ -271,6 +265,8 @@ class AttentionUNet(DeeplayModule): Number of classes. If num_classes are provided, the class embedding will be added to the positional encoding. This is used for the class conditioned models. context_embedding_dim : Optional[int] Dimension of the context embedding. Context embedding is defined outside the model and passed as an input to the model. The dimension of the context embedding should match the dimension given to the model. When enabled, the context embedding will be used to apply cross-attention to the feature maps. + num_attention_heads : dict + Number of attention heads for self-attention and cross-attention mechanisms. The keys should be "self" and "cross" respectively. Default is {"self": 1, "cross": 1}. """ in_channels: int @@ -281,6 +277,7 @@ class AttentionUNet(DeeplayModule): position_embedding_dim: int num_classes: Optional[int] context_embedding_dim: Optional[int] + num_attention_heads: dict def __init__( self, @@ -292,10 +289,12 @@ def __init__( position_embedding_dim: int = 16, num_classes: Optional[int] = None, context_embedding_dim: Optional[int] = None, + num_attention_heads: dict = {"self": 1, "cross": 1}, ): super().__init__() self.position_embedding_dim = position_embedding_dim self.context_embedding_dim = context_embedding_dim + self.num_attention_heads = num_attention_heads # Class embedding if num_classes is not None: @@ -306,7 +305,7 @@ def __init__( # Checks if len(channel_attention) != len(channels): raise ValueError( - "The number of attention flags should be equal to the number of channels." + "Length of channel_attention should be equal to the length of channels" ) # UNet encoder @@ -316,20 +315,22 @@ def __init__( channel_attention, position_embedding_dim, context_embedding_dim, + num_attention_heads, ) - # Base blocks self.base_blocks = LayerList() - self.base_blocks.append(DoubleConvBlock(channels[-1], base_channels[0])) + self.base_blocks.append(Block(channels[-1], base_channels[0])) for i in range(len(base_channels) - 1): - self.base_blocks.append( - DoubleConvBlock(base_channels[i], base_channels[i + 1]) - ) - self.base_blocks.append(DoubleConvBlock(base_channels[-1], channels[-1])) + self.base_blocks.append(Block(base_channels[i], base_channels[i + 1])) + self.base_blocks.append(Block(base_channels[-1], channels[-1])) # UNet decoder self.decoder = UNetDecoder( - channels, channel_attention, position_embedding_dim, context_embedding_dim + channels, + channel_attention, + position_embedding_dim, + context_embedding_dim, + num_attention_heads, ) # Output layer @@ -349,23 +350,17 @@ def forward(self, x, t, y=None, context=None): + "Please make sure that the embedding dimensions given to the model and the positional encoding function match." ) - if self.context_embedding_dim is not None: - if context is None: - raise ValueError( - "Context embedding is enabled. Please provide the context embedding." - ) - if context is not None: if context.shape[-1] != self.context_embedding_dim: raise ValueError( "Embedding dimension mismatch. " + f"Expected: {self.context_embedding_dim}, Got: {context.shape[2]}. " - + "Please make sure that the embedding dimensions given to the model and the context dimension provided in forward function match." + + "Please make sure that the context embedding dimensions provided while instantiating the model and the context embedding dimensions match." ) if y is not None: y = self.class_embedding(y) - t = t + y + t += y feature_maps = self.encoder(x, t, context) @@ -388,6 +383,7 @@ def configure( position_embedding_dim: int = 256, num_classes: Optional[int] = None, context_embedding_dim: Optional[int] = None, + num_attention_heads: dict = {"self": 1, "cross": 1}, ) -> None: ... configure = DeeplayModule.configure diff --git a/deeplay/tests/test_attention_unet.py b/deeplay/tests/test_attention_unet.py index 7af46d21..cc900913 100644 --- a/deeplay/tests/test_attention_unet.py +++ b/deeplay/tests/test_attention_unet.py @@ -121,3 +121,24 @@ def test_context_without_channel_attention(self): # Check output shape self.assertEqual(output.shape, (2, 1, 64, 64)) + + def test_self_attention_heads(self): + + attn_unet = AttentionUNet( + in_channels=1, + channels=[8, 16, 32], + base_channels=[64, 64], + channel_attention=[True, True, True], + out_channels=1, + position_embedding_dim=64, + num_attention_heads={"self": 2, "cross": 2}, + ) + attn_unet.build() + + # Test on a batch of 2 + x = torch.rand(2, 1, 64, 64) + t = torch.rand(2, 1) + output = attn_unet(x, positional_encoding(t, 64), y=None, context=None) + + # Check output shape + self.assertEqual(output.shape, (2, 1, 64, 64))