Refactor Attention UNet #115

Merged 2 commits on Jun 20, 2024
156 changes: 76 additions & 80 deletions deeplay/components/diffusion/attention_unet.py
@@ -29,49 +29,36 @@ def __init__(
        )


-class DoubleConvBlock(DeeplayModule):
-    """
-    Connects two base blocks in series either with or without a skip connection. Acts like a residual block.
-    """
-
-    def __init__(self, in_channels, out_channels, residual=False):
-        super().__init__()
-        self.blocks = LayerList()
-        self.blocks.append(Block(in_channels, out_channels))
-        self.blocks.append(Block(out_channels, out_channels, activation=nn.Identity()))
-        self.residual = residual
-
-    def forward(self, x):
-        x_input = x
-        for block in self.blocks:
-            x = block(x)
-        return F.gelu(x_input + x) if self.residual else x
-
-
class AttentionBlock(DeeplayModule):
    """
-    Applies attention mechanism to the input tensor. Based on the input, it can handle both self-attention and cross-attention mechanisms. If context_embedding_dim is provided, it will apply cross-attention else it will apply self-attention.
+    Applies attention mechanism to the input tensor. Depending on the input, it can handle both self-attention and cross-attention mechanisms. If context_embedding_dim is provided, it will apply cross-attention, else it will apply self-attention.
    """

-    def __init__(self, channels, context_embedding_dim):
+    def __init__(self, channels, context_embedding_dim, num_attention_heads):
        super().__init__()
        self.channels = channels

-        # Self-attention part of the basic transformer
+        # Self-attention part of the basic transformer action
        self.layer_norm1 = Layer(nn.LayerNorm, [channels])
        self.self_attention = Layer(
-            nn.MultiheadAttention, channels, num_heads=1, batch_first=True
+            nn.MultiheadAttention,
+            channels,
+            num_heads=num_attention_heads["self"],
+            batch_first=True,
        )

        # Cross-attention if context is enabled
        if context_embedding_dim is not None:
            self.cross_attention = Layer(
-                nn.MultiheadAttention, channels, num_heads=1, batch_first=True
+                nn.MultiheadAttention,
+                channels,
+                num_heads=num_attention_heads["cross"],
+                batch_first=True,
            )
            self.layer_norm2 = Layer(nn.LayerNorm, [channels])
            self.context_projection = Layer(nn.Linear, context_embedding_dim, channels)

-        # Feedforward part of the basic transformer
+        # Feedforward part of the basic transformer action
        self.layer_norm3 = Layer(nn.LayerNorm, [channels])
        self.feed_forward = Sequential(
            Layer(nn.Linear, channels, channels),
@@ -98,10 +85,9 @@ def forward(self, x, context):
            x = x + cross_attention_output

        # Feedforward
-        x_input = x
-        x = self.layer_norm3(x)
-        x = self.feed_forward(x)
-        x = x + x_input
+        z = self.layer_norm3(x)
+        z = self.feed_forward(z)
+        x = x + z

        # Reshape back to original shape: [B, H*W, C] -> [B, C, H, W]
        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
@@ -110,50 +96,54 @@

class FeatureIntegrationModule(DeeplayModule):
    """
-    Integrates the features with context information (such as time, context) using positional encoding and self-attention.
+    Integrates the time and context information to the feature maps through residual connections and attention mechanisms.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
-        embedding_dim,
+        position_embedding_dim,
        context_embedding_dim,
+        num_attention_heads,
        enable_attention=True,
    ):
        super().__init__()
-        self.conv_block = (
-            Sequential(
-                DoubleConvBlock(in_channels, in_channels, residual=True),
-                DoubleConvBlock(in_channels, out_channels),
-            )
-            if enable_attention
-            else DoubleConvBlock(in_channels, out_channels)
-        )
-        self.feed_forward_positional_embedding = Layer(
-            nn.Linear, embedding_dim, out_channels
+
+        self.blocks = LayerList()
+        self.blocks.append(Block(in_channels, out_channels))
+        self.blocks.append(Block(out_channels, out_channels))
+
+        # self.res_block = Block(in_channels, out_channels)
+        self.res_block = Layer(nn.Conv2d, in_channels, out_channels, kernel_size=1)
+
+        self.feed_forward_position_embedding = Layer(
+            nn.Linear, position_embedding_dim, out_channels
        )

        self.enable_attention = enable_attention
        if self.enable_attention:
            self.attention_layer = AttentionBlock(
-                channels=out_channels, context_embedding_dim=context_embedding_dim
+                channels=out_channels,
+                context_embedding_dim=context_embedding_dim,
+                num_attention_heads=num_attention_heads,
            )

    def forward(self, x, t, context):
-        x = self.conv_block(x)
-        # Tranform the time step embedding to the channel dimension of the input feature map
-        emb = self.feed_forward_positional_embedding(t)
-        # Repeat and reshape the embedding to match the spatial dimensions of the input feature map
-        emb = emb[:, :, None, None].repeat(1, 1, x.shape[2], x.shape[3])
+        h = self.blocks[0](x)

-        # Add the positional encoding to the input feature map
-        x = x + emb
+        # Project the positional encoding and add it to the input feature map
+        emb = self.feed_forward_position_embedding(t)
+        h += emb[:, :, None, None]
+        h = self.blocks[1](h)

-        # Apply self-attention to the input feature map
-        if self.enable_attention:
-            x = self.attention_layer(x, context)
+        # Residual connection to the input feature map
+        h += self.res_block(x)

-        return x
+        # Apply self-attention if enabled
+        if self.enable_attention:
+            h = self.attention_layer(h, context)
+        return h


class UNetEncoder(DeeplayModule):
@@ -170,30 +160,28 @@ def __init__(
        channel_attention,
        position_embedding_dim,
        context_embedding_dim,
+        num_attention_heads,
    ):
        super().__init__()
-        self.conv_block1 = DoubleConvBlock(in_channels, channels[0])
        self.blocks = LayerList()

-        for i in range(len(channels) - 1):
-            attention_flag = channel_attention[i + 1]
+        for i in range(len(channels)):
+            attention_flag = channel_attention[i]
            self.blocks.append(
                FeatureIntegrationModule(
+                    in_channels if i == 0 else channels[i - 1],
                    channels[i],
-                    channels[i + 1],
                    position_embedding_dim,
+                    context_embedding_dim,
+                    num_attention_heads,
                    enable_attention=attention_flag,
-                    context_embedding_dim=context_embedding_dim,
                )
            )

        self.pool = Layer(nn.MaxPool2d, kernel_size=2, stride=2)

    def forward(self, x, t, context):
        feature_maps = []
-        x = self.conv_block1(x)
-        feature_maps.append(x)
-        x = self.pool(x)

        for block in self.blocks:
            x = block(x, t, context)
            feature_maps.append(x)
@@ -209,7 +197,12 @@ class UNetDecoder(DeeplayModule):
    """

    def __init__(
-        self, channels, channel_attention, position_embedding_dim, context_embedding_dim
+        self,
+        channels,
+        channel_attention,
+        position_embedding_dim,
+        context_embedding_dim,
+        num_attention_heads,
    ):
        super().__init__()

@@ -234,8 +227,9 @@ def __init__(
                    channels[i],
                    channels[i + 1],
                    position_embedding_dim,
+                    context_embedding_dim,
+                    num_attention_heads,
                    enable_attention=attention_flag,
-                    context_embedding_dim=context_embedding_dim,
                ),
            )
        )
@@ -271,6 +265,8 @@ class AttentionUNet(DeeplayModule):
        Number of classes. If num_classes are provided, the class embedding will be added to the positional encoding. This is used for the class conditioned models.
    context_embedding_dim : Optional[int]
        Dimension of the context embedding. Context embedding is defined outside the model and passed as an input to the model. The dimension of the context embedding should match the dimension given to the model. When enabled, the context embedding will be used to apply cross-attention to the feature maps.
+    num_attention_heads : dict
+        Number of attention heads for self-attention and cross-attention mechanisms. The keys should be "self" and "cross" respectively. Default is {"self": 1, "cross": 1}.
    """

    in_channels: int
@@ -281,6 +277,7 @@ class AttentionUNet(DeeplayModule):
    position_embedding_dim: int
    num_classes: Optional[int]
    context_embedding_dim: Optional[int]
+    num_attention_heads: dict

    def __init__(
        self,
@@ -292,10 +289,12 @@ def __init__(
        position_embedding_dim: int = 16,
        num_classes: Optional[int] = None,
        context_embedding_dim: Optional[int] = None,
+        num_attention_heads: dict = {"self": 1, "cross": 1},
    ):
        super().__init__()
        self.position_embedding_dim = position_embedding_dim
        self.context_embedding_dim = context_embedding_dim
+        self.num_attention_heads = num_attention_heads

        # Class embedding
        if num_classes is not None:
@@ -306,7 +305,7 @@ def __init__(
        # Checks
        if len(channel_attention) != len(channels):
            raise ValueError(
-                "The number of attention flags should be equal to the number of channels."
+                "Length of channel_attention should be equal to the length of channels"
            )

        # UNet encoder
@@ -316,20 +315,22 @@ def __init__(
            channel_attention,
            position_embedding_dim,
            context_embedding_dim,
+            num_attention_heads,
        )

        # Base blocks
        self.base_blocks = LayerList()
-        self.base_blocks.append(DoubleConvBlock(channels[-1], base_channels[0]))
+        self.base_blocks.append(Block(channels[-1], base_channels[0]))
        for i in range(len(base_channels) - 1):
-            self.base_blocks.append(
-                DoubleConvBlock(base_channels[i], base_channels[i + 1])
-            )
-        self.base_blocks.append(DoubleConvBlock(base_channels[-1], channels[-1]))
+            self.base_blocks.append(Block(base_channels[i], base_channels[i + 1]))
+        self.base_blocks.append(Block(base_channels[-1], channels[-1]))

        # UNet decoder
        self.decoder = UNetDecoder(
-            channels, channel_attention, position_embedding_dim, context_embedding_dim
+            channels,
+            channel_attention,
+            position_embedding_dim,
+            context_embedding_dim,
+            num_attention_heads,
        )

        # Output layer
@@ -349,23 +350,17 @@ def forward(self, x, t, y=None, context=None):
                + "Please make sure that the embedding dimensions given to the model and the positional encoding function match."
            )

-        if self.context_embedding_dim is not None:
-            if context is None:
-                raise ValueError(
-                    "Context embedding is enabled. Please provide the context embedding."
-                )
-
        if context is not None:
            if context.shape[-1] != self.context_embedding_dim:
                raise ValueError(
                    "Embedding dimension mismatch. "
                    + f"Expected: {self.context_embedding_dim}, Got: {context.shape[2]}. "
-                    + "Please make sure that the embedding dimensions given to the model and the context dimension provided in forward function match."
+                    + "Please make sure that the context embedding dimensions provided while instantiating the model and the context embedding dimensions match."
                )

        if y is not None:
            y = self.class_embedding(y)
-            t = t + y
+            t += y

        feature_maps = self.encoder(x, t, context)

@@ -388,6 +383,7 @@ def configure(
        position_embedding_dim: int = 256,
        num_classes: Optional[int] = None,
        context_embedding_dim: Optional[int] = None,
+        num_attention_heads: dict = {"self": 1, "cross": 1},
    ) -> None: ...

    configure = DeeplayModule.configure
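As a quick illustration of the refactored interface, the sketch below instantiates AttentionUNet with per-mechanism head counts through the new num_attention_heads argument. This is a minimal sketch, not code from this PR: the import path, the sinusoidal positional_encoding helper, and the channel configuration are assumptions modelled on the test added below.

```python
import torch
from deeplay import AttentionUNet  # assumed import path


def positional_encoding(t, dim):
    # Minimal sinusoidal embedding (an assumption, not necessarily deeplay's own helper):
    # maps a [batch, 1] time tensor to a [batch, dim] embedding.
    freqs = torch.exp(
        -torch.arange(0, dim, 2, dtype=torch.float32)
        * (torch.log(torch.tensor(10000.0)) / dim)
    )
    angles = t * freqs
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


model = AttentionUNet(
    in_channels=1,
    channels=[8, 16, 32],
    base_channels=[64, 64],
    channel_attention=[True, True, True],
    out_channels=1,
    position_embedding_dim=64,
    num_attention_heads={"self": 2, "cross": 1},  # heads per attention mechanism
)
model.build()

x = torch.rand(2, 1, 64, 64)  # batch of 2 single-channel 64x64 images
t = torch.rand(2, 1)          # diffusion time steps
out = model(x, positional_encoding(t, 64), y=None, context=None)
print(out.shape)  # expected: torch.Size([2, 1, 64, 64])
```

With context_embedding_dim set, the same model would additionally build cross-attention layers sized by num_attention_heads["cross"].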
21 changes: 21 additions & 0 deletions deeplay/tests/test_attention_unet.py
@@ -121,3 +121,24 @@ def test_context_without_channel_attention(self):

        # Check output shape
        self.assertEqual(output.shape, (2, 1, 64, 64))
+
+    def test_self_attention_heads(self):
+
+        attn_unet = AttentionUNet(
+            in_channels=1,
+            channels=[8, 16, 32],
+            base_channels=[64, 64],
+            channel_attention=[True, True, True],
+            out_channels=1,
+            position_embedding_dim=64,
+            num_attention_heads={"self": 2, "cross": 2},
+        )
+        attn_unet.build()
+
+        # Test on a batch of 2
+        x = torch.rand(2, 1, 64, 64)
+        t = torch.rand(2, 1)
+        output = attn_unet(x, positional_encoding(t, 64), y=None, context=None)
+
+        # Check output shape
+        self.assertEqual(output.shape, (2, 1, 64, 64))
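A companion test for the cross-attention path could look like the following. This is a hypothetical sketch, not part of this PR: it assumes context is a [batch, tokens, context_embedding_dim] tensor (which is what the context.shape[-1] check in AttentionUNet.forward implies) and reuses the positional_encoding helper from this test module.

```python
    def test_cross_attention_heads(self):
        # Hypothetical companion test (not included in this PR).
        attn_unet = AttentionUNet(
            in_channels=1,
            channels=[8, 16, 32],
            base_channels=[64, 64],
            channel_attention=[True, True, True],
            out_channels=1,
            position_embedding_dim=64,
            context_embedding_dim=32,
            num_attention_heads={"self": 2, "cross": 4},
        )
        attn_unet.build()

        x = torch.rand(2, 1, 64, 64)
        t = torch.rand(2, 1)
        context = torch.rand(2, 10, 32)  # 10 context tokens of dimension 32

        output = attn_unet(x, positional_encoding(t, 64), y=None, context=context)
        self.assertEqual(output.shape, (2, 1, 64, 64))
```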