Skip to content

Commit

Permalink
Refactor ControlNetMaisi (#8005)
Browse files Browse the repository at this point in the history
Fixes #7988 .

### Description

Refactor ControlNetMaisi to use monai core components.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Pengfei Guo <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
3 people authored Aug 13, 2024
1 parent 6858114 commit 9dbfe16
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 33 deletions.
33 changes: 15 additions & 18 deletions monai/apps/generation/maisi/networks/controlnet_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,15 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence, cast
from typing import Sequence

import torch

from monai.utils import optional_import
from monai.networks.nets.controlnet import ControlNet
from monai.networks.nets.diffusion_model_unet import get_timestep_embedding

ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet")
get_timestep_embedding, has_get_timestep_embedding = optional_import(
"generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
)

if TYPE_CHECKING:
from generative.networks.nets.controlnet import ControlNet as ControlNetType
else:
ControlNetType = cast(type, ControlNet)


class ControlNetMaisi(ControlNetType):
class ControlNetMaisi(ControlNet):
"""
Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
Diffusion Models" (https://arxiv.org/abs/2302.05543)
Expand All @@ -49,10 +40,12 @@ class ControlNetMaisi(ControlNetType):
num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
classes.
upcast_attention: if True, upcast attention operations to full precision.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
use_checkpointing: if True, use activation checkpointing to save memory.
include_fc: whether to include the final linear layer. Default to False.
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
"""

def __init__(
Expand All @@ -71,10 +64,12 @@ def __init__(
cross_attention_dim: int | None = None,
num_class_embeds: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
conditioning_embedding_in_channels: int = 1,
conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256),
conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256),
use_checkpointing: bool = True,
include_fc: bool = False,
use_combined_linear: bool = False,
use_flash_attention: bool = False,
) -> None:
super().__init__(
spatial_dims,
Expand All @@ -91,9 +86,11 @@ def __init__(
cross_attention_dim,
num_class_embeds,
upcast_attention,
use_flash_attention,
conditioning_embedding_in_channels,
conditioning_embedding_num_channels,
include_fc,
use_combined_linear,
use_flash_attention,
)
self.use_checkpointing = use_checkpointing

Expand All @@ -105,7 +102,7 @@ def forward(
conditioning_scale: float = 1.0,
context: torch.Tensor | None = None,
class_labels: torch.Tensor | None = None,
) -> tuple[Sequence[torch.Tensor], torch.Tensor]:
) -> tuple[list[torch.Tensor], torch.Tensor]:
emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels)
h = self._apply_initial_convolution(x)
if self.use_checkpointing:
Expand Down
10 changes: 4 additions & 6 deletions monai/networks/nets/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,24 +174,22 @@ def __init__(
super().__init__()
if with_conditioning is True and cross_attention_dim is None:
raise ValueError(
"DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
"ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
"to be specified when with_conditioning=True."
)
if cross_attention_dim is not None and with_conditioning is False:
raise ValueError(
"DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
)
raise ValueError("ControlNet expects with_conditioning=True when specifying the cross_attention_dim.")

# All number of channels should be multiple of num_groups
if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
raise ValueError(
f"DiffusionModelUNet expects all channels to be a multiple of norm_num_groups, but got"
f"ControlNet expects all channels to be a multiple of norm_num_groups, but got"
f" channels={channels} and norm_num_groups={norm_num_groups}"
)

if len(channels) != len(attention_levels):
raise ValueError(
f"DiffusionModelUNet expects channels to have the same length as attention_levels, but got "
f"ControlNet expects channels to have the same length as attention_levels, but got "
f"channels={channels} and attention_levels={attention_levels}"
)

Expand Down
19 changes: 10 additions & 9 deletions tests/test_controlnet_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@
import torch
from parameterized import parameterized

from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi
from monai.networks import eval_mode
from monai.utils import optional_import
from tests.utils import SkipIfBeforePyTorchVersion

_, has_generative = optional_import("generative")

if has_generative:
from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi
_, has_einops = optional_import("einops")

TEST_CASES = [
[
Expand Down Expand Up @@ -103,16 +101,17 @@
TEST_CASES_ERROR = [
[
{"spatial_dims": 2, "in_channels": 1, "with_conditioning": True, "cross_attention_dim": None},
"ControlNet expects dimension of the cross-attention conditioning "
"(cross_attention_dim) when using with_conditioning.",
"ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
"to be specified when with_conditioning=True.",
],
[
{"spatial_dims": 2, "in_channels": 1, "with_conditioning": False, "cross_attention_dim": 2},
"ControlNet expects with_conditioning=True when specifying the cross_attention_dim.",
],
[
{"spatial_dims": 2, "in_channels": 1, "num_channels": (8, 16), "norm_num_groups": 16},
"ControlNet expects all num_channels being multiple of norm_num_groups",
f"ControlNet expects all channels to be a multiple of norm_num_groups, but got"
f" channels={(8, 16)} and norm_num_groups={16}",
],
[
{
Expand All @@ -122,16 +121,17 @@
"attention_levels": (True,),
"norm_num_groups": 8,
},
"ControlNet expects num_channels being same size of attention_levels",
f"ControlNet expects channels to have the same length as attention_levels, but got "
f"channels={(8, 16)} and attention_levels={(True,)}",
],
]


@SkipIfBeforePyTorchVersion((2, 0))
@skipUnless(has_generative, "monai-generative required")
class TestControlNet(unittest.TestCase):

@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape):
net = ControlNetMaisi(**input_param)
with eval_mode(net):
Expand All @@ -145,6 +145,7 @@ def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_
self.assertEqual(result[1].shape, expected_shape)

@parameterized.expand(TEST_CASES_CONDITIONAL)
@skipUnless(has_einops, "Requires einops")
def test_shape_conditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape):
net = ControlNetMaisi(**input_param)
with eval_mode(net):
Expand Down

0 comments on commit 9dbfe16

Please sign in to comment.