Refactor Attention UNet #115

Merged
merged 2 commits into DeepTrackAI:develop on Jun 20, 2024

Conversation

HarshithBachimanchi
Collaborator

This PR refactors AttentionUNet. The following changes make it simpler and more modular, and produce a much cleaner print output:

  1. DoubleConvBlock is removed. It was used to implement residual connections at all levels of the UNet; these are now integrated directly into the FeatureIntegrationModule through a couple of Block modules.
  2. The time step encoding (and the class and context embeddings) is now integrated in the middle of the residual connections, rather than at the end (see the sketch after this list).
  3. Added a new parameter num_attention_heads to control the number of heads in the self-attention and cross-attention layers. Also added a unit test for this.
  4. Removed a warning that was disabling classifier-free guidance for context inputs.
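
As a rough sketch of points 1 and 2 (not the actual implementation; the module and argument names here are hypothetical), a residual block that injects the embedding in the middle of the skip connection, rather than after it, could look like this:

```python
import torch
import torch.nn as nn


class ResidualBlockSketch(nn.Module):
    """Hypothetical sketch: the combined time step / class / context
    embedding is added in the middle of the residual branch (between the
    two conv blocks), and the skip connection is handled directly instead
    of through a separate DoubleConvBlock."""

    def __init__(self, channels, embedding_dim):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.GroupNorm(1, channels),
            nn.GELU(),
        )
        # Projects the embedding to the channel dimension so it can be
        # broadcast-added to the feature map.
        self.embed_proj = nn.Linear(embedding_dim, channels)
        self.block2 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.GroupNorm(1, channels),
            nn.GELU(),
        )

    def forward(self, x, emb):
        h = self.block1(x)
        # Embedding injected mid-block rather than at the end.
        h = h + self.embed_proj(emb)[:, :, None, None]
        h = self.block2(h)
        return x + h
```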

To refactor this into deeplay style, several styles would need to be implemented for Conv2dBlock and then integrated into UNet2d. Some of them already exist, but not in the form I need (for example, the styles spatial_self_attention and spatial_cross_attention). I gave it a quick try, and it looks possible to build it with styles, but that would require extensive testing. I prefer to keep AttentionUNet bespoke for now (unless you have other suggestions).

1. Removed DoubleConvBlock, which was used for the residual connections. A skip connection is now included directly in the FeatureIntegrationModule.
2. Time step embeddings (and other embeddings) are now added within the residual block, rather than at the end.
3. Removed some warnings.
4. Added the `num_attention_heads` parameter to control the self-attention and cross-attention heads in the model (see the usage sketch below).
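
A minimal usage sketch of the new parameter: only `num_attention_heads` is introduced by this PR; the import path and the other constructor arguments shown here are assumptions and may differ from the actual API.

```python
# Hypothetical usage sketch: only `num_attention_heads` comes from this PR.
# The import path and remaining constructor arguments are assumptions and
# may not match the actual deeplay API.
from deeplay import AttentionUNet

unet = AttentionUNet(
    channels=[64, 128, 256],   # illustrative channel configuration
    num_attention_heads=4,     # new: heads used by self- and cross-attention
)
print(unet)  # the refactor also yields a much cleaner printed representation
```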
@giovannivolpe merged commit cf60b97 into DeepTrackAI:develop on Jun 20, 2024
12 checks passed