From 2fb3901283a7a5fe8e7779c428639ac424915fb9 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Fri, 13 Dec 2024 15:54:56 -0800 Subject: [PATCH 1/2] vae training Signed-off-by: linnan wang --- nemo/collections/diffusion/vae/autoencoder.py | 225 ++++++++++- nemo/collections/diffusion/vae/autovae.py | 305 +++++++++++++++ nemo/collections/diffusion/vae/blocks.py | 147 ++++++- .../diffusion/vae/contperceptual_loss.py | 183 +++++++++ .../diffusion/vae/diffusers_vae.py | 32 ++ nemo/collections/diffusion/vae/readme.rst | 131 +++++++ .../collections/diffusion/vae/test_autovae.py | 144 +++++++ nemo/collections/diffusion/vae/train_vae.py | 365 ++++++++++++++++++ nemo/collections/diffusion/vae/train_vae.sh | 10 + .../diffusion/vae/vae16x/config.json | 35 ++ .../collections/diffusion/vae/validate_vae.py | 49 +++ 11 files changed, 1613 insertions(+), 13 deletions(-) create mode 100644 nemo/collections/diffusion/vae/autovae.py create mode 100644 nemo/collections/diffusion/vae/contperceptual_loss.py create mode 100644 nemo/collections/diffusion/vae/readme.rst create mode 100644 nemo/collections/diffusion/vae/test_autovae.py create mode 100644 nemo/collections/diffusion/vae/train_vae.py create mode 100644 nemo/collections/diffusion/vae/train_vae.sh create mode 100644 nemo/collections/diffusion/vae/vae16x/config.json create mode 100644 nemo/collections/diffusion/vae/validate_vae.py diff --git a/nemo/collections/diffusion/vae/autoencoder.py b/nemo/collections/diffusion/vae/autoencoder.py index b356d74baac1..234b8052b449 100644 --- a/nemo/collections/diffusion/vae/autoencoder.py +++ b/nemo/collections/diffusion/vae/autoencoder.py @@ -18,11 +18,45 @@ import torch from torch import Tensor, nn -from nemo.collections.diffusion.vae.blocks import AttnBlock, Downsample, Normalize, ResnetBlock, Upsample, make_attn +from nemo.collections.diffusion.vae.blocks import Downsample, Normalize, ResnetBlock, Upsample, make_attn @dataclass class AutoEncoderParams: + """Dataclass for storing autoencoder hyperparameters. + + Attributes + ---------- + ch_mult : list[int] + Channel multipliers at each resolution level. + attn_resolutions : list[int] + List of resolutions at which attention layers are applied. + resolution : int, optional + Input image resolution. Default is 256. + in_channels : int, optional + Number of input channels. Default is 3. + ch : int, optional + Base channel dimension. Default is 128. + out_ch : int, optional + Number of output channels. Default is 3. + num_res_blocks : int, optional + Number of residual blocks at each resolution. Default is 2. + z_channels : int, optional + Number of latent channels in the compressed representation. Default is 16. + scale_factor : float, optional + Scaling factor for latent representations. Default is 0.3611. + shift_factor : float, optional + Shift factor for latent representations. Default is 0.1159. + attn_type : str, optional + Type of attention to use ('vanilla', 'linear'). Default is 'vanilla'. + double_z : bool, optional + If True, produce both mean and log-variance for latent space. Default is True. + dropout : float, optional + Dropout rate. Default is 0.0. + ckpt : str or None, optional + Path to checkpoint file for loading pretrained weights. Default is None. + """ + ch_mult: list[int] attn_resolutions: list[int] resolution: int = 256 @@ -39,12 +73,55 @@ class AutoEncoderParams: ckpt: str = None -def nonlinearity(x): - # swish +def nonlinearity(x: Tensor) -> Tensor: + """Applies the SiLU (Swish) nonlinearity. 
+ + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Transformed tensor after applying SiLU activation. + """ return torch.nn.functional.silu(x) class Encoder(nn.Module): + """Encoder module that downsamples and encodes input images into a latent representation. + + Parameters + ---------- + ch : int + Base channel dimension. + out_ch : int + Number of output channels. + ch_mult : list[int] + Channel multipliers at each resolution level. + num_res_blocks : int + Number of residual blocks at each resolution level. + attn_resolutions : list[int] + List of resolutions at which attention layers are applied. + in_channels : int + Number of input image channels. + resolution : int + Input image resolution. + z_channels : int + Number of latent channels. + dropout : float, optional + Dropout rate. Default is 0.0. + resamp_with_conv : bool, optional + Whether to use convolutional resampling. Default is True. + double_z : bool, optional + If True, produce mean and log-variance channels for latent space. Default is True. + use_linear_attn : bool, optional + If True, use linear attention. Default is False. + attn_type : str, optional + Type of attention to use ('vanilla', 'linear'). Default is 'vanilla'. + """ + def __init__( self, *, @@ -117,7 +194,19 @@ def __init__( block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 ) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: + """Forward pass of the Encoder. + + Parameters + ---------- + x : torch.Tensor + Input image tensor of shape (B, C, H, W). + + Returns + ------- + torch.Tensor + Latent representation before sampling, with shape (B, 2*z_channels, H', W') if double_z=True. + """ # timestep embedding temb = None @@ -146,6 +235,40 @@ def forward(self, x): class Decoder(nn.Module): + """Decoder module that upscales and decodes latent representations back into images. + + Parameters + ---------- + ch : int + Base channel dimension. + out_ch : int + Number of output channels (e.g. 3 for RGB). + ch_mult : list[int] + Channel multipliers at each resolution level. + num_res_blocks : int + Number of residual blocks at each resolution level. + attn_resolutions : list[int] + List of resolutions at which attention layers are applied. + in_channels : int + Number of input image channels. + resolution : int + Input image resolution. + z_channels : int + Number of latent channels. + dropout : float, optional + Dropout rate. Default is 0.0. + resamp_with_conv : bool, optional + Whether to use convolutional resampling. Default is True. + give_pre_end : bool, optional + If True, returns the tensor before the final normalization and convolution. Default is False. + tanh_out : bool, optional + If True, applies a tanh activation to the output. Default is False. + use_linear_attn : bool, optional + If True, use linear attention. Default is False. + attn_type : str, optional + Type of attention to use ('vanilla', 'linear'). Default is 'vanilla'. + """ + def __init__( self, *, @@ -224,8 +347,19 @@ def __init__( self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - def forward(self, z): - # assert z.shape[1:] == self.z_shape[1:] + def forward(self, z: Tensor) -> Tensor: + """Forward pass of the Decoder. + + Parameters + ---------- + z : torch.Tensor + Latent representation of shape (B, z_channels, H', W'). + + Returns + ------- + torch.Tensor + Decoded image of shape (B, out_ch, H, W). 
+ """ self.last_z_shape = z.shape # timestep embedding @@ -261,12 +395,35 @@ def forward(self, z): class DiagonalGaussian(nn.Module): + """Module that splits an input tensor into mean and log-variance and optionally samples from the Gaussian. + + Parameters + ---------- + sample : bool, optional + If True, return a sample from the Gaussian. Otherwise, return the mean. Default is True. + chunk_dim : int, optional + Dimension along which to chunk the tensor into mean and log-variance. Default is 1. + """ + def __init__(self, sample: bool = True, chunk_dim: int = 1): super().__init__() self.sample = sample self.chunk_dim = chunk_dim def forward(self, z: Tensor) -> Tensor: + """Forward pass of the DiagonalGaussian module. + + Parameters + ---------- + z : torch.Tensor + Input tensor of shape (..., 2*z_channels, ...). + + Returns + ------- + torch.Tensor + If sample=True, returns a sampled tensor from N(mean, var). + If sample=False, returns the mean. + """ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) if self.sample: std = torch.exp(0.5 * logvar) @@ -276,6 +433,14 @@ def forward(self, z: Tensor) -> Tensor: class AutoEncoder(nn.Module): + """Full AutoEncoder model combining an Encoder, Decoder, and latent Gaussian sampling. + + Parameters + ---------- + params : AutoEncoderParams + Configuration parameters for the AutoEncoder model. + """ + def __init__(self, params: AutoEncoderParams): super().__init__() self.encoder = Encoder( @@ -314,21 +479,65 @@ def __init__(self, params: AutoEncoderParams): self.load_from_checkpoint(params.ckpt) def encode(self, x: Tensor) -> Tensor: + """Encode an input image to its latent representation. + + Parameters + ---------- + x : torch.Tensor + Input image of shape (B, C, H, W). + + Returns + ------- + torch.Tensor + Latent representation of the input image. + """ z = self.reg(self.encoder(x)) z = self.scale_factor * (z - self.shift_factor) return z def decode(self, z: Tensor) -> Tensor: + """Decode a latent representation back into an image. + + Parameters + ---------- + z : torch.Tensor + Latent representation of shape (B, z_channels, H', W'). + + Returns + ------- + torch.Tensor + Reconstructed image of shape (B, out_ch, H, W). + """ z = z / self.scale_factor + self.shift_factor return self.decoder(z) def forward(self, x: Tensor) -> Tensor: + """Forward pass that encodes and decodes the input image. + + Parameters + ---------- + x : torch.Tensor + Input image tensor. + + Returns + ------- + torch.Tensor + Reconstructed image. + """ return self.decode(self.encode(x)) - def load_from_checkpoint(self, ckpt_path): + def load_from_checkpoint(self, ckpt_path: str): + """Load the autoencoder weights from a checkpoint file. + + Parameters + ---------- + ckpt_path : str + Path to the checkpoint file. + """ from safetensors.torch import load_file as load_sft state_dict = load_sft(ckpt_path) missing, unexpected = self.load_state_dict(state_dict) if len(missing) > 0: - logger.warning(f"Following keys are missing from checkpoint loaded: {missing}") + # If logger is not defined, you may replace this with print or similar. 
+ print(f"Warning: Following keys are missing from checkpoint loaded: {missing}") diff --git a/nemo/collections/diffusion/vae/autovae.py b/nemo/collections/diffusion/vae/autovae.py new file mode 100644 index 000000000000..a7642886ca6c --- /dev/null +++ b/nemo/collections/diffusion/vae/autovae.py @@ -0,0 +1,305 @@ +import itertools +import time +from typing import Dict, List + +import torch +import torch.profiler +from diffusers import AutoencoderKL +from torch import nn + + +class VAEGenerator: + """ + A class for generating and searching different Variational Autoencoder (VAE) configurations. + + This class provides functionality to generate various VAE architecture configurations + given a specific input resolution and compression ratio. It allows searching through a + design space to find configurations that match given parameter and memory budgets. + """ + + def __init__(self, input_resolution: int = 1024, compression_ratio: int = 16) -> None: + if input_resolution == 1024: + assert compression_ratio in [8, 16] + elif input_resolution == 2048: + assert compression_ratio in [8, 16, 32] + else: + raise NotImplementedError("Higher resolution than 2028 is not implemented yet!") + + self._input_resolution = input_resolution + self._compression_ratio = compression_ratio + + def _generate_input(self): + """ + Generate a random input tensor with the specified input resolution. + + The tensor is placed on the GPU in half-precision (float16). + """ + random_tensor = torch.rand(1, 3, self.input_resolution, self.input_resolution) + random_tensor = random_tensor.to(dtype=torch.float16, device="cuda") + return random_tensor + + def _count_parameters(self, model: nn.Module = None): + """ + Count the number of trainable parameters in a given model. + + Args: + model (nn.Module): The model for which to count parameters. + + Returns: + int: The number of trainable parameters. + """ + assert model is not None, "Please provide a nn.Module to count the parameters." + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + def _load_base_json_skeleton(self): + """ + Load a base configuration skeleton for the VAE. + + Returns: + dict: A dictionary representing the base configuration JSON skeleton. + """ + skeleton = { + "_class_name": "AutoencoderKL", + "_diffusers_version": "0.20.0.dev0", + "_name_or_path": "../sdxl-vae/", + "act_fn": "silu", + "block_out_channels": [], + "down_block_types": [], + "force_upcast": False, + "in_channels": 3, + "latent_channels": -1, # 16 + "layers_per_block": -1, # 2 + "norm_num_groups": 32, + "out_channels": 3, + "sample_size": 1024, # resolution size + "scaling_factor": 0.13025, + "up_block_types": [], + } + return skeleton + + def _generate_all_combinations(self, attr): + """ + Generates all possible combinations from a search space dictionary. + + Args: + attr (dict): A dictionary where each key has a list of possible values. + + Returns: + List[Dict]: A list of dictionaries, each representing a unique combination of attributes. + """ + keys = list(attr.keys()) + choices = [attr[key] for key in keys] + all_combinations = list(itertools.product(*choices)) + + combination_dicts = [] + for combination in all_combinations: + combination_dict = {key: value for key, value in zip(keys, combination)} + combination_dicts.append(combination_dict) + + return combination_dicts + + def _assign_attributes(self, choice): + """ + Assign a chosen set of attributes to the base VAE configuration skeleton. 
+ + Args: + choice (dict): A dictionary of attributes to assign to the skeleton. + + Returns: + dict: A dictionary representing the updated VAE configuration. + """ + search_space_skleton = self._load_base_json_skeleton() + search_space_skleton["down_block_types"] = choice["down_block_types"] + search_space_skleton["up_block_types"] = choice["up_block_types"] + search_space_skleton["block_out_channels"] = choice["block_out_channels"] + search_space_skleton["layers_per_block"] = choice["layers_per_block"] + search_space_skleton["latent_channels"] = choice["latent_channels"] + return search_space_skleton + + def _search_space_16x1024(self): + """ + Define the search space for a 16x compression ratio at 1024 resolution. + + Returns: + dict: A dictionary defining lists of possible attribute values. + """ + attr = {} + attr["down_block_types"] = [["DownEncoderBlock2D"] * 5] + attr["up_block_types"] = [["UpDecoderBlock2D"] * 5] + attr["block_out_channels"] = [ + [128, 256, 512, 512, 512], + [128, 256, 512, 512, 1024], + [128, 256, 512, 1024, 2048], + [64, 128, 256, 512, 512], + ] + attr["layers_per_block"] = [1, 2, 3] + attr["latent_channels"] = [4, 16, 32, 64] + return attr + + def _search_space_8x1024(self): + """ + Define the search space for an 8x compression ratio at 1024 resolution. + + Returns: + dict: A dictionary defining lists of possible attribute values. + """ + attr = {} + attr["down_block_types"] = [["DownEncoderBlock2D"] * 4] + attr["up_block_types"] = [["UpDecoderBlock2D"] * 4] + attr["block_out_channels"] = [[128, 256, 512, 512], [128, 256, 512, 1024], [64, 128, 256, 512]] + attr["layers_per_block"] = [1, 2, 3] + attr["latent_channels"] = [4, 16, 32, 64] + return attr + + def _sort_data_in_place(self, data: List[Dict], mode: str) -> None: + """ + Sort the list of design configurations in place based on a chosen mode. + + Args: + data (List[Dict]): A list of dictionaries representing design configurations. + mode (str): The sorting criterion. Can be 'abs_param_diff', 'abs_cuda_mem_diff', or 'mse'. + """ + if mode == 'abs_param_diff': + data.sort(key=lambda x: abs(x['param_diff'])) + elif mode == 'abs_cuda_mem_diff': + data.sort(key=lambda x: abs(x['cuda_mem_diff'])) + elif mode == 'mse': + data.sort(key=lambda x: (x['param_diff'] ** 2 + x['cuda_mem_diff'] ** 2) / 2) + else: + raise ValueError("Invalid mode. Choose from 'abs_param_diff', 'abs_cuda_mem_diff', 'mse'.") + + def _print_table(self, data, headers, col_widths): + """ + Print a formatted table of the design choices. + + Args: + data (List[Dict]): The data to print, each entry a design configuration. + headers (List[str]): Column headers. + col_widths (List[int]): Widths for each column. + """ + # Create header row + header_row = "" + for header, width in zip(headers, col_widths): + header_row += f"{header:<{width}}" + print(header_row) + print("-" * sum(col_widths)) + + # Print each data row + for item in data: + row = f"{item['param_diff']:<{col_widths[0]}}" + row += f"{item['cuda_mem_diff']:<{col_widths[1]}}" + print(row) + + def search_for_target_vae(self, parameters_budget=0, cuda_max_mem=0): + """ + Search through available VAE design choices to find one that best matches + the given parameter and memory budgets. + + Args: + parameters_budget (float, optional): The target number of parameters (in millions). + cuda_max_mem (float, optional): The target maximum GPU memory usage (in MB). + + Returns: + AutoencoderKL: The chosen VAE configuration that best matches the provided budgets. 
+ """ + if parameters_budget <= 0 and cuda_max_mem <= 0: + raise ValueError("Please specify a valid parameter budget or cuda max memory budget") + + search_space_choices = [] + if self.input_resolution == 1024 and self.compression_ratio == 8: + search_space = self._search_space_8x1024() + search_space_choices = self._generate_all_combinations(search_space) + elif self.input_resolution == 1024 and self.compression_ratio == 16: + search_space = self._search_space_16x1024() + search_space_choices = self._generate_all_combinations(search_space) + + inp_tensor = self._generate_input() + inp_tensor = inp_tensor.to(dtype=torch.float16, device="cuda") + design_choices = [] + + for choice in search_space_choices: + parameters_budget_diff = 0 + cuda_max_mem_diff = 0 + + curt_design_json = self._assign_attributes(choice) + print("-" * 20) + print(choice) + vae = AutoencoderKL.from_config(curt_design_json) + vae = vae.to(dtype=torch.float16, device="cuda") + total_params = self._count_parameters(vae) + total_params /= 10**6 + # Reset peak memory statistics + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + profile_memory=True, # Enables memory profiling + record_shapes=True, # Records tensor shapes + with_stack=True, # Records stack traces + ) as prof: + # Perform forward pass + start_time = time.perf_counter() + with torch.no_grad(): + _ = vae.encode(inp_tensor).latent_dist.sample() + torch.cuda.synchronize() + end_time = time.perf_counter() + + total_execution_time_ms = (end_time - start_time) * 1000 + + # Get maximum memory allocated + max_memory_allocated = torch.cuda.max_memory_allocated() + max_memory_allocated = max_memory_allocated / (1024**2) + + parameters_budget_diff = parameters_budget - total_params + cuda_max_mem_diff = cuda_max_mem - max_memory_allocated + design_choices.append( + {"param_diff": parameters_budget_diff, "cuda_mem_diff": cuda_max_mem_diff, "design": curt_design_json} + ) + + print(f" Total params: {total_params}") + print(f" Max GPU Memory Usage: {max_memory_allocated} MB") + print(f" Total Execution Time: {total_execution_time_ms:.2f} ms") + + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + print("-" * 20) + sort_mode = "abs_param_diff" + if parameters_budget == 0: + sort_mode = "abs_cuda_mem_diff" + elif cuda_max_mem == 0: + sort_mode = "abs_param_diff" + else: + sort_mode = "mse" + + print("#" * 20) + self._sort_data_in_place(design_choices, sort_mode) + headers = ["param_diff (M)", "cuda_mem_diff (MB)"] + col_widths = [12, 15] + self._print_table(design_choices, headers, col_widths) + + vae = AutoencoderKL.from_config(design_choices[0]["design"]) + return vae + + @property + def input_resolution(self) -> int: + """ + Get the input resolution for the VAE. + + Returns: + int: The input resolution. + """ + return self._input_resolution + + @property + def compression_ratio(self) -> float: + """ + Get the compression ratio for the VAE. + + Returns: + float: The compression ratio. + """ + return self._compression_ratio diff --git a/nemo/collections/diffusion/vae/blocks.py b/nemo/collections/diffusion/vae/blocks.py index ad38a7a463cf..d942ba1ef4b0 100644 --- a/nemo/collections/diffusion/vae/blocks.py +++ b/nemo/collections/diffusion/vae/blocks.py @@ -26,11 +26,49 @@ def Normalize(in_channels, num_groups=32, act=""): + """Creates a group normalization layer with specified activation. 
+ + Args: + in_channels (int): Number of channels in the input. + num_groups (int, optional): Number of groups for GroupNorm. Defaults to 32. + act (str, optional): Activation function name. Defaults to "". + + Returns: + GroupNorm: A normalization layer with optional activation. + """ return GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, act=act) +def nonlinearity(x): + """Nonlinearity function used in temporal embedding projection. + + Currently implemented as a SiLU (Swish) function. + + Args: + x (Tensor): Input tensor. + + Returns: + Tensor: Output after applying SiLU activation. + """ + return x * torch.sigmoid(x) + + class ResnetBlock(nn.Module): + """A ResNet-style block that can optionally apply a temporal embedding and shortcut projections. + + This block consists of two convolutional layers, normalization, and optional temporal embedding. + It can adjust channel dimensions between input and output via shortcuts. + """ + def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=0): + """ + Args: + in_channels (int): Number of input channels. + out_channels (int, optional): Number of output channels. Defaults to in_channels. + conv_shortcut (bool, optional): Whether to use a convolutional shortcut. Defaults to False. + dropout (float, optional): Dropout probability. Defaults to 0.0. + temb_channels (int, optional): Number of channels in temporal embedding. Defaults to 0. + """ super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -51,6 +89,15 @@ def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout= self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x, temb): + """Forward pass of the ResnetBlock. + + Args: + x (Tensor): Input feature map of shape (B, C, H, W). + temb (Tensor): Temporal embedding tensor of shape (B, temb_channels). + + Returns: + Tensor: Output feature map of shape (B, out_channels, H, W). + """ h = x h = self.norm1(h) h = self.conv1(h) @@ -72,16 +119,32 @@ def forward(self, x, temb): class Upsample(nn.Module): + """Upsampling block that increases spatial resolution by a factor of 2. + + Can optionally include a convolution after upsampling. + """ + def __init__(self, in_channels, with_conv): + """ + Args: + in_channels (int): Number of input channels. + with_conv (bool): If True, apply a convolution after upsampling. + """ super().__init__() self.with_conv = with_conv if self.with_conv: self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): + """Forward pass of the Upsample block. + + Args: + x (Tensor): Input feature map (B, C, H, W). + + Returns: + Tensor: Upsampled feature map (B, C, 2H, 2W). + """ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # TODO(yuya): Remove this cast once the issue is fixed in PyTorch - # https://github.com/pytorch/pytorch/issues/86679 dtype = x.dtype if dtype == torch.bfloat16: x = x.to(torch.float32) @@ -94,7 +157,17 @@ def forward(self, x): class Downsample(nn.Module): + """Downsampling block that reduces spatial resolution by a factor of 2. + + Can optionally include a convolution before downsampling. + """ + def __init__(self, in_channels, with_conv): + """ + Args: + in_channels (int): Number of input channels. + with_conv (bool): If True, apply a convolution before downsampling. 
+ """ super().__init__() self.with_conv = with_conv if self.with_conv: @@ -102,6 +175,14 @@ def __init__(self, in_channels, with_conv): self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) def forward(self, x): + """Forward pass of the Downsample block. + + Args: + x (Tensor): Input feature map (B, C, H, W). + + Returns: + Tensor: Downsampled feature map (B, C, H/2, W/2). + """ if self.with_conv: pad = (0, 1, 0, 1) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) @@ -112,7 +193,16 @@ def forward(self, x): class AttnBlock(nn.Module): + """Self-attention block that applies scaled dot-product attention to feature maps. + + Normalizes input, computes queries, keys, and values, then applies attention and a projection. + """ + def __init__(self, in_channels: int): + """ + Args: + in_channels (int): Number of input/output channels. + """ super().__init__() self.in_channels = in_channels @@ -124,6 +214,14 @@ def __init__(self, in_channels: int): self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) def attention(self, h_: Tensor) -> Tensor: + """Compute the attention over the input feature maps. + + Args: + h_ (Tensor): Normalized input feature map (B, C, H, W). + + Returns: + Tensor: Output after applying scaled dot-product attention (B, C, H, W). + """ h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) @@ -138,11 +236,30 @@ def attention(self, h_: Tensor) -> Tensor: return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) def forward(self, x: Tensor) -> Tensor: + """Forward pass of the AttnBlock. + + Args: + x (Tensor): Input feature map (B, C, H, W). + + Returns: + Tensor: Output feature map after self-attention (B, C, H, W). + """ return x + self.proj_out(self.attention(x)) class LinearAttention(nn.Module): + """Linear Attention block for efficient attention computations. + + Uses linear attention mechanisms to reduce complexity and memory usage. + """ + def __init__(self, dim, heads=4, dim_head=32): + """ + Args: + dim (int): Input channel dimension. + heads (int, optional): Number of attention heads. Defaults to 4. + dim_head (int, optional): Dimension per attention head. Defaults to 32. + """ super().__init__() self.heads = heads hidden_dim = dim_head * heads @@ -150,6 +267,14 @@ def __init__(self, dim, heads=4, dim_head=32): self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): + """Forward pass of the LinearAttention block. + + Args: + x (Tensor): Input feature map (B, C, H, W). + + Returns: + Tensor: Output feature map after linear attention (B, C, H, W). + """ b, c, h, w = x.shape qkv = self.to_qkv(x) q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3) @@ -161,15 +286,27 @@ def forward(self, x): class LinAttnBlock(LinearAttention): - """ - to match AttnBlock usage - """ + """Wrapper class to provide a linear attention block in a form compatible with other attention blocks.""" def __init__(self, in_channels): + """ + Args: + in_channels (int): Number of input/output channels. + """ super().__init__(dim=in_channels, heads=1, dim_head=in_channels) def make_attn(in_channels, attn_type="vanilla"): + """Factory function to create an attention block. + + Args: + in_channels (int): Number of input/output channels. + attn_type (str, optional): Type of attention block to create. Options: "vanilla", "linear", "none". + Defaults to "vanilla". + + Returns: + nn.Module: An instance of the requested attention block. 
+ """ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' print(f"making attention of type '{attn_type}' with {in_channels} in_channels") if attn_type == "vanilla": diff --git a/nemo/collections/diffusion/vae/contperceptual_loss.py b/nemo/collections/diffusion/vae/contperceptual_loss.py new file mode 100644 index 000000000000..b63230dbbb6f --- /dev/null +++ b/nemo/collections/diffusion/vae/contperceptual_loss.py @@ -0,0 +1,183 @@ +import torch +import torch.nn as nn + +from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? + + +class LPIPSWithDiscriminator(nn.Module): + """ + A perceptual loss module that combines LPIPS with an adversarial discriminator + for improved reconstruction quality in variational autoencoders. This class + calculates a combination of pixel-level, perceptual (LPIPS), KL, and adversarial + losses for training a VAE model with a discriminator. + """ + + def __init__( + self, + disc_start, + logvar_init=0.0, + kl_weight=1.0, + pixelloss_weight=1.0, + disc_num_layers=3, + disc_in_channels=3, + disc_factor=1.0, + disc_weight=1.0, + perceptual_weight=1.0, + use_actnorm=False, + disc_conditional=False, + disc_loss="hinge", + ): + """ + Initializes the LPIPSWithDiscriminator module. + + Args: + disc_start (int): Iteration at which to start discriminator updates. + logvar_init (float): Initial value for the log variance parameter. + kl_weight (float): Weight for the KL divergence term. + pixelloss_weight (float): Weight for the pixel-level reconstruction loss. + disc_num_layers (int): Number of layers in the discriminator. + disc_in_channels (int): Number of input channels for the discriminator. + disc_factor (float): Scaling factor for the discriminator loss. + disc_weight (float): Weight applied to the discriminator gradient balancing. + perceptual_weight (float): Weight for the LPIPS perceptual loss. + use_actnorm (bool): Whether to use actnorm in the discriminator. + disc_conditional (bool): Whether the discriminator is conditional on an additional input. + disc_loss (str): Type of GAN loss to use ("hinge" or "vanilla"). + """ + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.ones(1) * logvar_init) + + self.discriminator = NLayerDiscriminator( + input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + """ + Computes an adaptive weight that balances the reconstruction (NLL) and the + adversarial (GAN) losses. This ensures stable training by adjusting the + impact of the discriminator’s gradient on the generator. + + Args: + nll_loss (torch.Tensor): The negative log-likelihood loss. + g_loss (torch.Tensor): The generator (adversarial) loss. + last_layer (torch.nn.Parameter, optional): Last layer parameters of the model + for gradient-based calculations. If None, uses self.last_layer[0]. + + Returns: + torch.Tensor: The computed adaptive weight for balancing the discriminator. 
+ """ + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward( + self, inputs, reconstructions, posteriors, optimizer_idx, global_step, last_layer=None, cond=None, weights=None + ): + """ + Forward pass for computing the combined loss. Depending on the optimizer index, + this either computes the generator loss (including pixel, perceptual, KL, and + adversarial terms) or the discriminator loss. + + Args: + inputs (torch.Tensor): Original inputs to reconstruct. + reconstructions (torch.Tensor): Reconstructed outputs from the model. + posteriors (object): Posteriors from the VAE model for KL computation. + optimizer_idx (int): Indicates which optimizer is being updated + (0 for generator, 1 for discriminator). + global_step (int): Current training iteration step. + last_layer (torch.nn.Parameter, optional): The last layer's parameters for + adaptive weight calculation. + cond (torch.Tensor, optional): Conditional input for the discriminator. + weights (torch.Tensor, optional): Sample-wise weighting for the losses. + + Returns: + (torch.Tensor, dict): A tuple of (loss, log_dict) where loss is the computed loss + for the current optimizer and log_dict is a dictionary of intermediate values + for logging and debugging. + """ + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights * nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + + log = { + "total_loss": loss.clone().detach().mean(), + "logvar": self.logvar.detach().item(), + "kl_loss": kl_loss.detach().mean(), + "nll_loss": nll_loss.detach().mean(), + "rec_loss": rec_loss.detach().mean(), + "d_weight": d_weight.detach(), + "disc_factor": torch.tensor(disc_factor), + "g_loss": g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # discriminator 
update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = { + "disc_loss": d_loss.clone().detach().mean(), + "logits_real": logits_real.detach().mean(), + "logits_fake": logits_fake.detach().mean(), + } + return d_loss, log diff --git a/nemo/collections/diffusion/vae/diffusers_vae.py b/nemo/collections/diffusion/vae/diffusers_vae.py index 19a056d4a682..fe8f50ce658b 100644 --- a/nemo/collections/diffusion/vae/diffusers_vae.py +++ b/nemo/collections/diffusion/vae/diffusers_vae.py @@ -18,12 +18,44 @@ class AutoencoderKLVAE(torch.nn.Module): + """ + A class that wraps the AutoencoderKL model and provides a decode method. + + Attributes: + vae (AutoencoderKL): The underlying AutoencoderKL model loaded from a pretrained path. + """ + def __init__(self, path): + """ + Initialize the AutoencoderKLVAE instance. + + Args: + path (str): The path to the pretrained AutoencoderKL model. + """ super().__init__() self.vae = AutoencoderKL.from_pretrained(path, torch_dtype=torch.bfloat16) @torch.no_grad() def decode(self, x): + """ + Decode a latent representation using the underlying VAE model. + + This method takes a latent tensor `x` and decodes it into an image. + If `x` has a temporal dimension `T` of 1, it + rearranges the tensor before and after decoding. + + Args: + x (torch.Tensor): A tensor of shape (B, C, T, H, W), where: + B = batch size + C = number of channels + T = temporal dimension + H = height + W = width + + Returns: + torch.Tensor: Decoded image tensor with the same shape as the input (B, C, T, H, W). + """ + B, C, T, H, W = x.shape if T == 1: x = rearrange(x, 'b c t h w -> (b t) c h w') diff --git a/nemo/collections/diffusion/vae/readme.rst b/nemo/collections/diffusion/vae/readme.rst new file mode 100644 index 000000000000..ac0f2b2f5e71 --- /dev/null +++ b/nemo/collections/diffusion/vae/readme.rst @@ -0,0 +1,131 @@ +============================ +Pretraining Variational AutoEncoder +============================ + +Variational Autoencoder (VAE) is a data compression technique that compresses high-resolution images into a lower-dimensional latent space, capturing essential features while reducing dimensionality. This process allows for efficient storage and processing of image data. VAE has been integral to training Stable Diffusion (SD) models, significantly reducing computational requirements. For instance, SDLX utilizes a VAE that reduces image dimensions by 8x, greatly optimizing the training and inference processes. In this repository, we provide training codes to pretrain the VAE from scratch, enabling users to achieve higher compression ratios in the spatial dimension, such as 16x or 32x. + +Installation +============ + +Please pull the latest NeMo docker to get started, see details about NeMo docker `here `_. + +Validation +======== +We also provide a validation code for you to quickly evaluate our pretrained 16x VAE model on a 50K dataset. Once you start the docker, run the following script to start the testing. + +.. 
code-block:: bash + + torchrun --nproc-per-node 8 nemo/collections/diffusion/vae/validate_vae.py --yes data.path=path/to/validation/data log.log_dir=/path/to/checkpoint + +Configure the following variables: + + +1. ``data.path``: Set this to the directory containing your test data (e.g., `.jpg` or `.png` files). The original and VAE-reconstructed images will be logged side by side in Weights & Biases (wandb). + +2. ``log.log_dir``: Set this to the directory containing the pretrained checkpoint. You can find our pretrained checkpoint at ``TODO by ethan`` + +Here are some sample images generated from our pretrained VAE. + +``Left``: Original Image, ``Right``: 16x VAE Reconstructed Image + +.. list-table:: + :align: center + + * - .. image:: https://github.com/user-attachments/assets/08122f5b-2e65-4d65-87d7-eceae9d158fb + :width: 1400 + :align: center + - .. image:: https://github.com/user-attachments/assets/6e805a0d-8783-4d24-a65b-d96a6ba1555d + :width: 1400 + :align: center + - .. image:: https://github.com/user-attachments/assets/aab1ef33-35da-444d-90ee-ba3ad58a6b2d + :width: 1400 + :align: center + +Data Preparation +======== + +1. we expect data to be in the form of WebDataset tar files. If you have a folder of images, you can use `tar` to convert them into WebDataset tar files: + + .. code-block:: bash + + 000000.tar + ├── 1.jpg + ├── 2.jpg + 000001.tar + ├── 3.jpg + ├── 4.jpg + +2. next we need to index the webdataset with `energon `_. navigate to the dataset directory and run the following command: + + .. code-block:: bash + + energon prepare . --num-workers 8 --shuffle-tars + +3. then select dataset type `ImageWebdataset` and specify the type `jpg`. Below is an example of the interactive setup: + + .. code-block:: bash + + Found 2925 tar files in total. The first and last ones are: + - 000000.tar + - 002924.tar + If you want to exclude some of them, cancel with ctrl+c and specify an exclude filter in the command line. + Please enter a desired train/val/test split like "0.5, 0.2, 0.3" or "8,1,1": 99,1,0 + Indexing shards [####################################] 2925/2925 + Sample 0, keys: + - jpg + Sample 1, keys: + - jpg + Found the following part types in the dataset: jpg + Do you want to create a dataset.yaml interactively? [Y/n]: + The following dataset classes are available: + 0. CaptioningWebdataset + 1. CrudeWebdataset + 2. ImageClassificationWebdataset + 3. ImageWebdataset + 4. InterleavedWebdataset + 5. MultiChoiceVQAWebdataset + 6. OCRWebdataset + 7. SimilarityInterleavedWebdataset + 8. TextWebdataset + 9. VQAOCRWebdataset + 10. VQAWebdataset + 11. VidQAWebdataset + Please enter a number to choose a class: 3 + The dataset you selected uses the following sample type: + + @dataclass + class ImageSample(Sample): + """Sample type for an image, e.g. for image reconstruction.""" + + #: The input image tensor in the shape (C, H, W) + image: torch.Tensor + + Do you want to set a simple field_map[Y] (or write your own sample_loader [n])? [Y/n]: + + For each field, please specify the corresponding name in the WebDataset. + Available types in WebDataset: jpg + Leave empty for skipping optional field + You may also access json fields e.g. by setting the field to: json[field][field] + You may also specify alternative fields e.g. by setting to: jpg,png + Please enter the field_map for ImageWebdataset: + Please enter a webdataset field name for 'image' (): + That type doesn't exist in the WebDataset. Please try again. + Please enter a webdataset field name for 'image' (): jpg + Done + +4. 
finally, you can use the indexed dataset to train the VAE model. specify `data.path=/path/to/dataset` in the training script `train_vae.py`. + +Training +======== + +We provide a sample training script for launching multi-node training. Simply configure ``data.path`` to point to your prepared dataset to get started. + +.. code-block:: bash + + bash nemo/collections/diffusion/vae/train_vae.sh \ + data.path=xxx + + + + + diff --git a/nemo/collections/diffusion/vae/test_autovae.py b/nemo/collections/diffusion/vae/test_autovae.py new file mode 100644 index 000000000000..b76df4ce67b7 --- /dev/null +++ b/nemo/collections/diffusion/vae/test_autovae.py @@ -0,0 +1,144 @@ +import unittest + +import torch +from autovae import VAEGenerator + + +class TestVAEGenerator(unittest.TestCase): + """Unit tests for the VAEGenerator class.""" + + def setUp(self): + # Common setup for tests + self.input_resolution = 1024 + self.compression_ratio = 8 + self.generator = VAEGenerator(input_resolution=self.input_resolution, compression_ratio=self.compression_ratio) + + def test_initialization_valid(self): + """Test that valid initialization parameters set the correct properties.""" + generator = VAEGenerator(input_resolution=1024, compression_ratio=8) + self.assertEqual(generator.input_resolution, 1024) + self.assertEqual(generator.compression_ratio, 8) + + generator = VAEGenerator(input_resolution=2048, compression_ratio=16) + self.assertEqual(generator.input_resolution, 2048) + self.assertEqual(generator.compression_ratio, 16) + + def test_initialization_invalid(self): + """Test that invalid initialization parameters raise an error.""" + with self.assertRaises(NotImplementedError): + VAEGenerator(input_resolution=4096, compression_ratio=16) + + def test_generate_input(self): + """Test that _generate_input produces a tensor with the correct shape and device.""" + input_tensor = self.generator._generate_input() + expected_shape = (1, 3, self.input_resolution, self.input_resolution) + self.assertEqual(input_tensor.shape, expected_shape) + self.assertEqual(input_tensor.dtype, torch.float16) + self.assertEqual(input_tensor.device.type, "cuda") + + def test_count_parameters(self): + """Test that _count_parameters correctly counts model parameters.""" + model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 5)) + expected_param_count = sum(p.numel() for p in model.parameters() if p.requires_grad) + param_count = self.generator._count_parameters(model) + self.assertEqual(param_count, expected_param_count) + + def test_load_base_json_skeleton(self): + """Test that _load_base_json_skeleton returns the correct skeleton.""" + skeleton = self.generator._load_base_json_skeleton() + expected_keys = { + "_class_name", + "_diffusers_version", + "_name_or_path", + "act_fn", + "block_out_channels", + "down_block_types", + "force_upcast", + "in_channels", + "latent_channels", + "layers_per_block", + "norm_num_groups", + "out_channels", + "sample_size", + "scaling_factor", + "up_block_types", + } + self.assertEqual(set(skeleton.keys()), expected_keys) + + def test_generate_all_combinations(self): + """Test that _generate_all_combinations generates all possible combinations.""" + attr = {"layers_per_block": [1, 2], "latent_channels": [4, 8]} + combinations = self.generator._generate_all_combinations(attr) + expected_combinations = [ + {"layers_per_block": 1, "latent_channels": 4}, + {"layers_per_block": 1, "latent_channels": 8}, + {"layers_per_block": 2, "latent_channels": 4}, + {"layers_per_block": 
2, "latent_channels": 8}, + ] + self.assertEqual(len(combinations), len(expected_combinations)) + for combo in expected_combinations: + self.assertIn(combo, combinations) + + def test_assign_attributes(self): + """Test that _assign_attributes correctly assigns attributes to the skeleton.""" + choice = { + "down_block_types": ["DownEncoderBlock2D"] * 4, + "up_block_types": ["UpDecoderBlock2D"] * 4, + "block_out_channels": [64, 128, 256, 512], + "layers_per_block": 2, + "latent_channels": 16, + } + skeleton = self.generator._assign_attributes(choice) + self.assertEqual(skeleton["down_block_types"], choice["down_block_types"]) + self.assertEqual(skeleton["up_block_types"], choice["up_block_types"]) + self.assertEqual(skeleton["block_out_channels"], choice["block_out_channels"]) + self.assertEqual(skeleton["layers_per_block"], choice["layers_per_block"]) + self.assertEqual(skeleton["latent_channels"], choice["latent_channels"]) + + def test_search_space_16x1024(self): + """Test that _search_space_16x1024 returns the correct search space.""" + search_space = self.generator._search_space_16x1024() + expected_keys = { + "down_block_types", + "up_block_types", + "block_out_channels", + "layers_per_block", + "latent_channels", + } + self.assertEqual(set(search_space.keys()), expected_keys) + self.assertTrue(all(isinstance(v, list) for v in search_space.values())) + + def test_sort_data_in_place(self): + """Test that _sort_data_in_place correctly sorts data based on the specified mode.""" + data = [ + {"param_diff": 10, "cuda_mem_diff": 100}, + {"param_diff": 5, "cuda_mem_diff": 50}, + {"param_diff": -3, "cuda_mem_diff": 30}, + {"param_diff": 7, "cuda_mem_diff": 70}, + ] + # Test sorting by absolute parameter difference + self.generator._sort_data_in_place(data, mode="abs_param_diff") + expected_order_param = [-3, 5, 7, 10] + actual_order_param = [item["param_diff"] for item in data] + self.assertEqual(actual_order_param, expected_order_param) + + # Test sorting by absolute CUDA memory difference + self.generator._sort_data_in_place(data, mode="abs_cuda_mem_diff") + expected_order_mem = [30, 50, 70, 100] + actual_order_mem = [item["cuda_mem_diff"] for item in data] + self.assertEqual(actual_order_mem, expected_order_mem) + + # Test sorting by mean squared error (MSE) + self.generator._sort_data_in_place(data, mode="mse") + expected_order_mse = [-3, 5, 7, 10] # Computed based on MSE values + actual_order_mse = [item["param_diff"] for item in data] + self.assertEqual(actual_order_mse, expected_order_mse) + + def test_search_for_target_vae_invalid(self): + """Test that search_for_target_vae raises an error when no budget is specified.""" + with self.assertRaises(ValueError): + self.generator.search_for_target_vae(parameters_budget=0, cuda_max_mem=0) + + +if __name__ == "__main__": + unittest.main() diff --git a/nemo/collections/diffusion/vae/train_vae.py b/nemo/collections/diffusion/vae/train_vae.py new file mode 100644 index 000000000000..c9748407b011 --- /dev/null +++ b/nemo/collections/diffusion/vae/train_vae.py @@ -0,0 +1,365 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Callable, Dict, Optional, Sequence, Tuple + +import nemo_run as run +import torch +import torch.distributed +import torch.utils.checkpoint +import torchvision +import wandb +from autovae import VAEGenerator +from contperceptual_loss import LPIPSWithDiscriminator +from diffusers import AutoencoderKL +from megatron.core import parallel_state +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.energon import DefaultTaskEncoder, ImageSample +from torch import Tensor, nn + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.collections.diffusion.data.diffusion_energon_datamodule import DiffusionDataModule +from nemo.collections.diffusion.train import pretrain +from nemo.collections.llm.gpt.model.base import GPTModel +from nemo.lightning.io.mixin import IOMixin +from nemo.lightning.megatron_parallel import DataT, MegatronLossReduction, ReductionT +from nemo.lightning.pytorch.optim import OptimizerModule + + +class AvgLossReduction(MegatronLossReduction): + """Performs average loss reduction across micro-batches.""" + + def forward(self, batch: DataT, forward_out: Tensor) -> Tuple[Tensor, ReductionT]: + """ + Forward pass for loss reduction. + + Args: + batch: The batch of data. + forward_out: The output tensor from forward computation. + + Returns: + A tuple of (loss, reduction dictionary). + """ + loss = forward_out.mean() + return loss, {"avg": loss} + + def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> Tensor: + """ + Reduce losses across multiple micro-batches by averaging them. + + Args: + losses_reduced_per_micro_batch: A sequence of loss dictionaries. + + Returns: + The averaged loss tensor. + """ + losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch]) + return losses.mean() + + +class VAE(MegatronModule): + """Variational Autoencoder (VAE) module.""" + + def __init__(self, config, pretrained_model_name_or_path, search_vae=False): + """ + Initialize the VAE model. + + Args: + config: Transformer configuration. + pretrained_model_name_or_path: Path or name of the pretrained model. + search_vae: Flag to indicate whether to search for a target VAE using AutoVAE. + """ + super().__init__(config) + if search_vae: + # Get VAE automatically from AutoVAE + self.vae = VAEGenerator(input_resolution=1024, compression_ratio=16) + # Below line is commented out due to an undefined 'generator' variable in original code snippet. 
+ # self.vae = generator.search_for_target_vae(parameters_budget=895.178707, cuda_max_mem=0) + else: + self.vae = AutoencoderKL.from_config(pretrained_model_name_or_path, weight_dtype=torch.bfloat16) + + sdxl_vae = AutoencoderKL.from_pretrained( + 'stabilityai/stable-diffusion-xl-base-1.0', subfolder="vae", weight_dtype=torch.bfloat16 + ) + sd_dict = sdxl_vae.state_dict() + vae_dict = self.vae.state_dict() + pre_dict = {k: v for k, v in sd_dict.items() if (k in vae_dict) and (vae_dict[k].numel() == v.numel())} + self.vae.load_state_dict(pre_dict, strict=False) + del sdxl_vae + + self.vae_loss = LPIPSWithDiscriminator( + disc_start=50001, + logvar_init=0.0, + kl_weight=0.000001, + pixelloss_weight=1.0, + disc_num_layers=3, + disc_in_channels=3, + disc_factor=1.0, + disc_weight=0.5, + perceptual_weight=1.0, + use_actnorm=False, + disc_conditional=False, + disc_loss="hinge", + ) + + def forward(self, target, global_step): + """ + Forward pass through the VAE. + + Args: + target: Target images. + global_step: Current global step. + + Returns: + A tuple (aeloss, log_dict_ae, pred) containing the loss, log dictionary, and predictions. + """ + posterior = self.vae.encode(target).latent_dist + z = posterior.sample() + pred = self.vae.decode(z).sample + aeloss, log_dict_ae = self.vae_loss( + inputs=target, + reconstructions=pred, + posteriors=posterior, + optimizer_idx=0, + global_step=global_step, + last_layer=self.vae.decoder.conv_out.weight, + ) + return aeloss, log_dict_ae, pred + + def set_input_tensor(self, input_tensor: Tensor) -> None: + """ + Set input tensor. + + Args: + input_tensor: The input tensor to the model. + """ + pass + + +class VAEModel(GPTModel): + """A GPTModel wrapper for the VAE.""" + + def __init__( + self, + pretrained_model_name_or_path: str, + optim: Optional[OptimizerModule] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, + ): + """ + Initialize the VAEModel. + + Args: + pretrained_model_name_or_path: Path or name of the pretrained model. + optim: Optional optimizer module. + model_transform: Optional function to transform the model. + """ + self.pretrained_model_name_or_path = pretrained_model_name_or_path + config = TransformerConfig(num_layers=1, hidden_size=1, num_attention_heads=1) + self.model_type = ModelType.encoder_or_decoder + super().__init__(config, optim=optim, model_transform=model_transform) + + def configure_model(self) -> None: + """Configure the model by initializing the module.""" + if not hasattr(self, "module"): + self.module = VAE(self.config, self.pretrained_model_name_or_path) + + def data_step(self, dataloader_iter) -> Dict[str, Any]: + """ + Perform a single data step to fetch a batch from the iterator. + + Args: + dataloader_iter: The dataloader iterator. + + Returns: + A dictionary with 'pixel_values' ready for the model. + """ + batch = next(dataloader_iter)[0] + return {'pixel_values': batch.image.to(device='cuda', dtype=torch.bfloat16, non_blocking=True)} + + def forward(self, *args, **kwargs): + """ + Forward pass through the underlying module. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + The result of forward pass of self.module. + """ + return self.module(*args, **kwargs) + + def training_step(self, batch, batch_idx=None) -> torch.Tensor: + """ + Perform a single training step. + + Args: + batch: The input batch. + batch_idx: Batch index. + + Returns: + The loss tensor. 
+ """ + loss, log_dict_ae, pred = self(batch["pixel_values"], self.global_step) + + if torch.distributed.get_rank() == 0: + self.log_dict(log_dict_ae) + + return loss + + def validation_step(self, batch, batch_idx=None) -> torch.Tensor: + """ + Perform a single validation step. + + Args: + batch: The input batch. + batch_idx: Batch index. + + Returns: + The loss tensor. + """ + loss, log_dict_ae, pred = self(batch["pixel_values"], self.global_step) + + image = torch.cat([batch["pixel_values"].cpu(), pred.cpu()], axis=0) + image = (image + 0.5).clamp(0, 1) + + # wandb is on the last rank for megatron, first rank for nemo + wandb_rank = 0 + + if parallel_state.get_data_parallel_src_rank() == wandb_rank: + if torch.distributed.get_rank() == wandb_rank: + gather_list = [None for _ in range(parallel_state.get_data_parallel_world_size())] + else: + gather_list = None + torch.distributed.gather_object( + image, gather_list, wandb_rank, group=parallel_state.get_data_parallel_group() + ) + if gather_list is not None: + self.log_dict(log_dict_ae) + wandb.log( + { + "Original (left), Reconstruction (right)": [ + wandb.Image(torchvision.utils.make_grid(image)) for _, image in enumerate(gather_list) + ] + }, + ) + + return loss + + @property + def training_loss_reduction(self) -> AvgLossReduction: + """Returns the loss reduction method for training.""" + if not self._training_loss_reduction: + self._training_loss_reduction = AvgLossReduction() + return self._training_loss_reduction + + @property + def validation_loss_reduction(self) -> AvgLossReduction: + """Returns the loss reduction method for validation.""" + if not self._validation_loss_reduction: + self._validation_loss_reduction = AvgLossReduction() + return self._validation_loss_reduction + + def on_validation_model_zero_grad(self) -> None: + """ + Hook to handle zero grad on validation model step. + Used here to skip first validation on resume. + """ + super().on_validation_model_zero_grad() + if self.trainer.ckpt_path is not None and getattr(self, '_restarting_skip_val_flag', True): + self.trainer.sanity_checking = True + self._restarting_skip_val_flag = False + + +def crop_image(img, divisor=16): + """ + Crop the image so that both dimensions are divisible by the given divisor. + + Args: + img: Image tensor. + divisor: The divisor to use for cropping. + + Returns: + The cropped image tensor. + """ + h, w = img.shape[-2], img.shape[-1] + + delta_h = h % divisor + delta_w = w % divisor + + delta_h_top = delta_h // 2 + delta_h_bottom = delta_h - delta_h_top + + delta_w_left = delta_w // 2 + delta_w_right = delta_w - delta_w_left + + img_cropped = img[..., delta_h_top : h - delta_h_bottom, delta_w_left : w - delta_w_right] + + return img_cropped + + +class ImageTaskEncoder(DefaultTaskEncoder, IOMixin): + """Image task encoder that crops and normalizes the image.""" + + def encode_sample(self, sample: ImageSample) -> ImageSample: + """ + Encode a single image sample by cropping and shifting its values. + + Args: + sample: An image sample. + + Returns: + The transformed image sample. + """ + sample = super().encode_sample(sample) + sample.image = crop_image(sample.image, 16) + sample.image -= 0.5 + return sample + + +@run.cli.factory(target=llm.train) +def train_vae() -> run.Partial: + """ + Training factory function for VAE. + + Returns: + A run.Partial recipe for training. 
+ """ + recipe = pretrain() + recipe.model = run.Config( + VAEModel, + pretrained_model_name_or_path='nemo/collections/diffusion/vae/vae16x/config.json', + ) + recipe.data = run.Config( + DiffusionDataModule, + task_encoder=run.Config(ImageTaskEncoder), + global_batch_size=24, + num_workers=10, + ) + recipe.optim.lr_scheduler = run.Config(nl.lr_scheduler.WarmupHoldPolicyScheduler, warmup_steps=100, hold_steps=1e9) + recipe.optim.config.lr = 5e-6 + recipe.optim.config.weight_decay = 1e-2 + recipe.log.log_dir = 'nemo_experiments/train_vae' + recipe.trainer.val_check_interval = 1000 + recipe.trainer.callbacks[0].every_n_train_steps = 1000 + + return recipe + + +if __name__ == "__main__": + run.cli.main(llm.train, default_factory=train_vae) diff --git a/nemo/collections/diffusion/vae/train_vae.sh b/nemo/collections/diffusion/vae/train_vae.sh new file mode 100644 index 000000000000..3f5a46ab9f65 --- /dev/null +++ b/nemo/collections/diffusion/vae/train_vae.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +#SBATCH -p batch -A coreai_dlalgo_llm -t 4:00:00 --nodes=16 --exclusive --mem=0 --overcommit --gpus-per-node 8 --ntasks-per-node=8 --dependency=singleton + +export WANDB_RESUME=allow +export WANDB_NAME=train_vae + +DIR=`pwd` + +srun --signal=TERM@300 -l --container-image ${IMAGE} --container-mounts "/lustre:/lustre/,/home:/home" --no-container-mount-home --mpi=pmix bash -c "cd ${DIR} ; python -u nemo/collections/diffusion/vae/train_vae.py --yes $*" diff --git a/nemo/collections/diffusion/vae/vae16x/config.json b/nemo/collections/diffusion/vae/vae16x/config.json new file mode 100644 index 000000000000..9b363564eed2 --- /dev/null +++ b/nemo/collections/diffusion/vae/vae16x/config.json @@ -0,0 +1,35 @@ +{ + "_class_name": "AutoencoderKL", + "_diffusers_version": "0.20.0.dev0", + "_name_or_path": "../sdxl-vae/", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 512, + 1024, + 2048 + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + "force_upcast": false, + "in_channels": 3, + "latent_channels": 16, + "layers_per_block": 2, + "norm_num_groups": 32, + "out_channels": 3, + "sample_size": 1024, + "scaling_factor": 0.13025, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ] + } diff --git a/nemo/collections/diffusion/vae/validate_vae.py b/nemo/collections/diffusion/vae/validate_vae.py new file mode 100644 index 000000000000..dd143d9e0b33 --- /dev/null +++ b/nemo/collections/diffusion/vae/validate_vae.py @@ -0,0 +1,49 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import nemo_run as run +from nemo.collections import llm +from nemo.collections.diffusion.vae.train_vae import train_vae + + +@run.cli.factory(target=llm.validate) +def validate_vae() -> run.Partial: + """ + Create a partial function for validating a VAE (Variational Autoencoder) model. 
+ + This function uses the training recipe defined in `train_vae()` to set up + the model, data, trainer, logging, and optimization configurations for + validation. It returns a Partial object that can be used by the NeMo run CLI + to execute the validation procedure on the provided model and data. + + Returns: + run.Partial: A partial object configured with llm.validate target + and all necessary arguments extracted from the VAE training recipe. + """ + recipe = train_vae() + return run.Partial( + llm.validate, + model=recipe.model, + data=recipe.data, + trainer=recipe.trainer, + log=recipe.log, + optim=recipe.optim, + tokenizer=None, + resume=recipe.resume, + model_transform=None, + ) + + +if __name__ == "__main__": + run.cli.main(llm.validate, default_factory=validate_vae) From 937b2fff5645a6949f1751f2d4c8d7035c7084f9 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Fri, 13 Dec 2024 19:26:38 -0800 Subject: [PATCH 2/2] vae training Signed-off-by: linnan wang --- nemo/collections/diffusion/vae/autovae.py | 14 ++++++++++++++ .../diffusion/vae/contperceptual_loss.py | 14 ++++++++++++++ nemo/collections/diffusion/vae/test_autovae.py | 14 ++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/nemo/collections/diffusion/vae/autovae.py b/nemo/collections/diffusion/vae/autovae.py index a7642886ca6c..0797036f9cc0 100644 --- a/nemo/collections/diffusion/vae/autovae.py +++ b/nemo/collections/diffusion/vae/autovae.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import itertools import time from typing import Dict, List diff --git a/nemo/collections/diffusion/vae/contperceptual_loss.py b/nemo/collections/diffusion/vae/contperceptual_loss.py index b63230dbbb6f..7021e31f7f3b 100644 --- a/nemo/collections/diffusion/vae/contperceptual_loss.py +++ b/nemo/collections/diffusion/vae/contperceptual_loss.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch import torch.nn as nn diff --git a/nemo/collections/diffusion/vae/test_autovae.py b/nemo/collections/diffusion/vae/test_autovae.py index b76df4ce67b7..fa414c20c4ce 100644 --- a/nemo/collections/diffusion/vae/test_autovae.py +++ b/nemo/collections/diffusion/vae/test_autovae.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest import torch