vae training and autovae (#11592)
* vae training

Signed-off-by: linnan wang <[email protected]>

* vae training

Signed-off-by: linnan wang <[email protected]>

---------

Signed-off-by: linnan wang <[email protected]>
Signed-off-by: Abhinav Garg <[email protected]>
linnanwang authored and abhinavg4 committed Jan 30, 2025
1 parent 87fcc87 commit 7e9cdaf
Showing 11 changed files with 1,655 additions and 13 deletions.
225 changes: 217 additions & 8 deletions nemo/collections/diffusion/vae/autoencoder.py
@@ -18,11 +18,45 @@
import torch
from torch import Tensor, nn

-from nemo.collections.diffusion.vae.blocks import AttnBlock, Downsample, Normalize, ResnetBlock, Upsample, make_attn
+from nemo.collections.diffusion.vae.blocks import Downsample, Normalize, ResnetBlock, Upsample, make_attn


@dataclass
class AutoEncoderParams:
"""Dataclass for storing autoencoder hyperparameters.
Attributes
----------
ch_mult : list[int]
Channel multipliers at each resolution level.
attn_resolutions : list[int]
List of resolutions at which attention layers are applied.
resolution : int, optional
Input image resolution. Default is 256.
in_channels : int, optional
Number of input channels. Default is 3.
ch : int, optional
Base channel dimension. Default is 128.
out_ch : int, optional
Number of output channels. Default is 3.
num_res_blocks : int, optional
Number of residual blocks at each resolution. Default is 2.
z_channels : int, optional
Number of latent channels in the compressed representation. Default is 16.
scale_factor : float, optional
Scaling factor for latent representations. Default is 0.3611.
shift_factor : float, optional
Shift factor for latent representations. Default is 0.1159.
attn_type : str, optional
Type of attention to use ('vanilla', 'linear'). Default is 'vanilla'.
double_z : bool, optional
If True, produce both mean and log-variance for latent space. Default is True.
dropout : float, optional
Dropout rate. Default is 0.0.
ckpt : str or None, optional
Path to checkpoint file for loading pretrained weights. Default is None.
"""

ch_mult: list[int]
attn_resolutions: list[int]
resolution: int = 256
@@ -39,12 +39,55 @@ class AutoEncoderParams:
    ckpt: str = None
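
For orientation, a minimal construction sketch: only the two fields without defaults must be supplied. The values shown are illustrative assumptions, not values taken from this commit.

params = AutoEncoderParams(
    ch_mult=[1, 2, 4, 4],   # assumed: four resolution levels, 8x total downsampling
    attn_resolutions=[],    # assumed: no attention blocks at any resolution
)
# Everything else falls back to the defaults documented above (ch=128, z_channels=16, ...).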


-def nonlinearity(x):
-    # swish
+def nonlinearity(x: Tensor) -> Tensor:
    """Applies the SiLU (Swish) nonlinearity.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.

    Returns
    -------
    torch.Tensor
        Transformed tensor after applying SiLU activation.
    """
    return torch.nn.functional.silu(x)
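
For reference, SiLU is simply x * sigmoid(x), so the call above is equivalent to the explicit form:

import torch

x = torch.randn(4)
assert torch.allclose(torch.nn.functional.silu(x), x * torch.sigmoid(x))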


class Encoder(nn.Module):
"""Encoder module that downsamples and encodes input images into a latent representation.
Parameters
----------
ch : int
Base channel dimension.
out_ch : int
Number of output channels.
ch_mult : list[int]
Channel multipliers at each resolution level.
num_res_blocks : int
Number of residual blocks at each resolution level.
attn_resolutions : list[int]
List of resolutions at which attention layers are applied.
in_channels : int
Number of input image channels.
resolution : int
Input image resolution.
z_channels : int
Number of latent channels.
dropout : float, optional
Dropout rate. Default is 0.0.
resamp_with_conv : bool, optional
Whether to use convolutional resampling. Default is True.
double_z : bool, optional
If True, produce mean and log-variance channels for latent space. Default is True.
use_linear_attn : bool, optional
If True, use linear attention. Default is False.
attn_type : str, optional
Type of attention to use ('vanilla', 'linear'). Default is 'vanilla'.
"""

    def __init__(
        self,
        *,
@@ -117,7 +194,19 @@ def __init__(
            block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
        )

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the Encoder.

        Parameters
        ----------
        x : torch.Tensor
            Input image tensor of shape (B, C, H, W).

        Returns
        -------
        torch.Tensor
            Latent representation before sampling, with shape (B, 2*z_channels, H', W') if double_z=True.
        """
        # timestep embedding
        temb = None
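
To make the downsampling arithmetic concrete, here is a small shape-check sketch. It assumes the Encoder follows the standard convolutional VAE layout (downsampling at every level but the last), and the hyperparameter values are illustrative, not taken from this commit.

import torch

# Assumed config: three levels halve 256 -> 128 -> 64; double_z doubles the channel count.
enc = Encoder(
    ch=64, out_ch=3, ch_mult=[1, 2, 4], num_res_blocks=2,
    attn_resolutions=[], in_channels=3, resolution=256, z_channels=16,
)
x = torch.randn(1, 3, 256, 256)
z = enc(x)
print(z.shape)  # expected: torch.Size([1, 32, 64, 64]) -- 2*z_channels since double_z=True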

@@ -146,6 +235,40 @@ def forward(self, x):


class Decoder(nn.Module):
"""Decoder module that upscales and decodes latent representations back into images.
Parameters
----------
ch : int
Base channel dimension.
out_ch : int
Number of output channels (e.g. 3 for RGB).
ch_mult : list[int]
Channel multipliers at each resolution level.
num_res_blocks : int
Number of residual blocks at each resolution level.
attn_resolutions : list[int]
List of resolutions at which attention layers are applied.
in_channels : int
Number of input image channels.
resolution : int
Input image resolution.
z_channels : int
Number of latent channels.
dropout : float, optional
Dropout rate. Default is 0.0.
resamp_with_conv : bool, optional
Whether to use convolutional resampling. Default is True.
give_pre_end : bool, optional
If True, returns the tensor before the final normalization and convolution. Default is False.
tanh_out : bool, optional
If True, applies a tanh activation to the output. Default is False.
use_linear_attn : bool, optional
If True, use linear attention. Default is False.
attn_type : str, optional
Type of attention to use ('vanilla', 'linear'). Default is 'vanilla'.
"""

    def __init__(
        self,
        *,
@@ -224,8 +347,19 @@ def __init__(
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

-    def forward(self, z):
-        # assert z.shape[1:] == self.z_shape[1:]
+    def forward(self, z: Tensor) -> Tensor:
        """Forward pass of the Decoder.

        Parameters
        ----------
        z : torch.Tensor
            Latent representation of shape (B, z_channels, H', W').

        Returns
        -------
        torch.Tensor
            Decoded image of shape (B, out_ch, H, W).
        """
        self.last_z_shape = z.shape

        # timestep embedding
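
A matching shape sketch for the decode direction, again with assumed illustrative hyperparameters (mirroring the Encoder sketch above):

dec = Decoder(
    ch=64, out_ch=3, ch_mult=[1, 2, 4], num_res_blocks=2,
    attn_resolutions=[], in_channels=3, resolution=256, z_channels=16,
)
z = torch.randn(1, 16, 64, 64)  # z_channels, not 2*z_channels: sampling halves the channels
x_rec = dec(z)
print(x_rec.shape)  # expected: torch.Size([1, 3, 256, 256])
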
@@ -261,12 +395,35 @@ def forward(self, z):


class DiagonalGaussian(nn.Module):
"""Module that splits an input tensor into mean and log-variance and optionally samples from the Gaussian.
Parameters
----------
sample : bool, optional
If True, return a sample from the Gaussian. Otherwise, return the mean. Default is True.
chunk_dim : int, optional
Dimension along which to chunk the tensor into mean and log-variance. Default is 1.
"""

    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        """Forward pass of the DiagonalGaussian module.

        Parameters
        ----------
        z : torch.Tensor
            Input tensor of shape (..., 2*z_channels, ...).

        Returns
        -------
        torch.Tensor
            If sample=True, returns a sampled tensor from N(mean, var).
            If sample=False, returns the mean.
        """
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
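
The sampling step here is the standard reparameterization trick; a self-contained sketch of the same computation:

import torch

z = torch.randn(1, 32, 64, 64)                 # 2*z_channels: mean and log-variance stacked
mean, logvar = torch.chunk(z, 2, dim=1)        # -> two (1, 16, 64, 64) tensors
std = torch.exp(0.5 * logvar)                  # log-variance -> standard deviation
sample = mean + std * torch.randn_like(mean)   # differentiable sample from N(mean, std^2)
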
@@ -276,6 +433,14 @@ def forward(self, z: Tensor) -> Tensor:


class AutoEncoder(nn.Module):
"""Full AutoEncoder model combining an Encoder, Decoder, and latent Gaussian sampling.
Parameters
----------
params : AutoEncoderParams
Configuration parameters for the AutoEncoder model.
"""

    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.encoder = Encoder(
@@ -314,21 +479,65 @@ def __init__(self, params: AutoEncoderParams):
            self.load_from_checkpoint(params.ckpt)

    def encode(self, x: Tensor) -> Tensor:
        """Encode an input image to its latent representation.

        Parameters
        ----------
        x : torch.Tensor
            Input image of shape (B, C, H, W).

        Returns
        -------
        torch.Tensor
            Latent representation of the input image.
        """
        z = self.reg(self.encoder(x))
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        """Decode a latent representation back into an image.

        Parameters
        ----------
        z : torch.Tensor
            Latent representation of shape (B, z_channels, H', W').

        Returns
        -------
        torch.Tensor
            Reconstructed image of shape (B, out_ch, H, W).
        """
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass that encodes and decodes the input image.

        Parameters
        ----------
        x : torch.Tensor
            Input image tensor.

        Returns
        -------
        torch.Tensor
            Reconstructed image.
        """
        return self.decode(self.encode(x))

-    def load_from_checkpoint(self, ckpt_path):
+    def load_from_checkpoint(self, ckpt_path: str):
        """Load the autoencoder weights from a checkpoint file.

        Parameters
        ----------
        ckpt_path : str
            Path to the checkpoint file.
        """
        from safetensors.torch import load_file as load_sft

        state_dict = load_sft(ckpt_path)
        missing, unexpected = self.load_state_dict(state_dict)
        if len(missing) > 0:
-            logger.warning(f"Following keys are missing from checkpoint loaded: {missing}")
+            print(f"Warning: Following keys are missing from checkpoint loaded: {missing}")