vae training and autovae #11592

Merged 5 commits on Dec 31, 2024
225 changes: 217 additions & 8 deletions nemo/collections/diffusion/vae/autoencoder.py
@@ -18,11 +18,45 @@
import torch
from torch import Tensor, nn

-from nemo.collections.diffusion.vae.blocks import AttnBlock, Downsample, Normalize, ResnetBlock, Upsample, make_attn
+from nemo.collections.diffusion.vae.blocks import Downsample, Normalize, ResnetBlock, Upsample, make_attn


@dataclass
class AutoEncoderParams:
"""Dataclass for storing autoencoder hyperparameters.

Attributes
----------
ch_mult : list[int]
Channel multipliers at each resolution level.
attn_resolutions : list[int]
List of resolutions at which attention layers are applied.
resolution : int, optional
Input image resolution. Default is 256.
in_channels : int, optional
Number of input channels. Default is 3.
ch : int, optional
Base channel dimension. Default is 128.
out_ch : int, optional
Number of output channels. Default is 3.
num_res_blocks : int, optional
Number of residual blocks at each resolution. Default is 2.
z_channels : int, optional
Number of latent channels in the compressed representation. Default is 16.
scale_factor : float, optional
Scaling factor for latent representations. Default is 0.3611.
shift_factor : float, optional
Shift factor for latent representations. Default is 0.1159.
attn_type : str, optional
Type of attention to use ('vanilla', 'linear'). Default is 'vanilla'.
double_z : bool, optional
If True, produce both mean and log-variance for latent space. Default is True.
dropout : float, optional
Dropout rate. Default is 0.0.
ckpt : str or None, optional
Path to checkpoint file for loading pretrained weights. Default is None.
"""

ch_mult: list[int]
attn_resolutions: list[int]
resolution: int = 256
@@ -39,12 +39,55 @@ class AutoEncoderParams:
ckpt: str | None = None
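
For illustration, one plausible way to instantiate this dataclass (the ch_mult and attn_resolutions values below are assumptions for the sketch, not taken from this PR):

# Hypothetical example config; ch_mult/attn_resolutions values are assumed.
example_params = AutoEncoderParams(
    ch_mult=[1, 2, 4, 4],     # four resolution levels (assumed)
    attn_resolutions=[],      # no attention blocks in this sketch
    resolution=256,
    in_channels=3,
    z_channels=16,
    ckpt=None,
)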


-def nonlinearity(x):
-    # swish
+def nonlinearity(x: Tensor) -> Tensor:
"""Applies the SiLU (Swish) nonlinearity.

Parameters
----------
x : torch.Tensor
Input tensor.

Returns
-------
torch.Tensor
Transformed tensor after applying SiLU activation.
"""
return torch.nn.functional.silu(x)
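
A quick sanity check of the identity SiLU(x) = x · sigmoid(x):

# SiLU (Swish) equals x * sigmoid(x); quick equivalence check.
_x = torch.randn(4)
assert torch.allclose(nonlinearity(_x), _x * torch.sigmoid(_x))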


class Encoder(nn.Module):
"""Encoder module that downsamples and encodes input images into a latent representation.

Parameters
----------
ch : int
Base channel dimension.
out_ch : int
Number of output channels.
ch_mult : list[int]
Channel multipliers at each resolution level.
num_res_blocks : int
Number of residual blocks at each resolution level.
attn_resolutions : list[int]
List of resolutions at which attention layers are applied.
in_channels : int
Number of input image channels.
resolution : int
Input image resolution.
z_channels : int
Number of latent channels.
dropout : float, optional
Dropout rate. Default is 0.0.
resamp_with_conv : bool, optional
Whether to use convolutional resampling. Default is True.
double_z : bool, optional
If True, produce mean and log-variance channels for latent space. Default is True.
use_linear_attn : bool, optional
If True, use linear attention. Default is False.
attn_type : str, optional
Type of attention to use ('vanilla', 'linear'). Default is 'vanilla'.
"""

def __init__(
self,
*,
@@ -117,7 +194,19 @@ def __init__(
block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
)

-def forward(self, x):
+def forward(self, x: Tensor) -> Tensor:
"""Forward pass of the Encoder.

Parameters
----------
x : torch.Tensor
Input image tensor of shape (B, C, H, W).

Returns
-------
torch.Tensor
Latent representation before sampling, with shape (B, 2*z_channels, H', W') if double_z=True.
"""
# timestep embedding
temb = None

@@ -146,6 +235,40 @@ def forward(self, x):


class Decoder(nn.Module):
"""Decoder module that upscales and decodes latent representations back into images.

Parameters
----------
ch : int
Base channel dimension.
out_ch : int
Number of output channels (e.g. 3 for RGB).
ch_mult : list[int]
Channel multipliers at each resolution level.
num_res_blocks : int
Number of residual blocks at each resolution level.
attn_resolutions : list[int]
List of resolutions at which attention layers are applied.
in_channels : int
Number of input image channels.
resolution : int
Input image resolution.
z_channels : int
Number of latent channels.
dropout : float, optional
Dropout rate. Default is 0.0.
resamp_with_conv : bool, optional
Whether to use convolutional resampling. Default is True.
give_pre_end : bool, optional
If True, returns the tensor before the final normalization and convolution. Default is False.
tanh_out : bool, optional
If True, applies a tanh activation to the output. Default is False.
use_linear_attn : bool, optional
If True, use linear attention. Default is False.
attn_type : str, optional
Type of attention to use ('vanilla', 'linear'). Default is 'vanilla'.
"""

def __init__(
self,
*,
@@ -224,8 +347,19 @@ def __init__(
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

-def forward(self, z):
-    # assert z.shape[1:] == self.z_shape[1:]
+def forward(self, z: Tensor) -> Tensor:
"""Forward pass of the Decoder.

Parameters
----------
z : torch.Tensor
Latent representation of shape (B, z_channels, H', W').

Returns
-------
torch.Tensor
Decoded image of shape (B, out_ch, H, W).
"""
self.last_z_shape = z.shape

# timestep embedding
@@ -261,12 +395,35 @@ def forward(self, z):


class DiagonalGaussian(nn.Module):
"""Module that splits an input tensor into mean and log-variance and optionally samples from the Gaussian.

Parameters
----------
sample : bool, optional
If True, return a sample from the Gaussian. Otherwise, return the mean. Default is True.
chunk_dim : int, optional
Dimension along which to chunk the tensor into mean and log-variance. Default is 1.
"""

def __init__(self, sample: bool = True, chunk_dim: int = 1):
super().__init__()
self.sample = sample
self.chunk_dim = chunk_dim

def forward(self, z: Tensor) -> Tensor:
"""Forward pass of the DiagonalGaussian module.

Parameters
----------
z : torch.Tensor
Input tensor of shape (..., 2*z_channels, ...).

Returns
-------
torch.Tensor
If sample=True, returns a sampled tensor from N(mean, var).
If sample=False, returns the mean.
"""
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if self.sample:
std = torch.exp(0.5 * logvar)
@@ -276,6 +433,14 @@ def forward(self, z: Tensor) -> Tensor:
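
The collapsed lines above complete the usual reparameterization trick; a minimal standalone sketch of that sampling step, assuming z packs mean and log-variance along dim 1:

# Standalone sketch of reparameterized sampling (shapes are illustrative).
_z = torch.randn(2, 32, 8, 8)                     # 2 * z_channels = 32 (assumed)
_mean, _logvar = torch.chunk(_z, 2, dim=1)
_std = torch.exp(0.5 * _logvar)
_sample = _mean + _std * torch.randn_like(_mean)  # differentiable in mean/logvar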


class AutoEncoder(nn.Module):
"""Full AutoEncoder model combining an Encoder, Decoder, and latent Gaussian sampling.

Parameters
----------
params : AutoEncoderParams
Configuration parameters for the AutoEncoder model.
"""

def __init__(self, params: AutoEncoderParams):
super().__init__()
self.encoder = Encoder(
@@ -314,21 +479,65 @@ def __init__(self, params: AutoEncoderParams):
self.load_from_checkpoint(params.ckpt)

def encode(self, x: Tensor) -> Tensor:
"""Encode an input image to its latent representation.

Parameters
----------
x : torch.Tensor
Input image of shape (B, C, H, W).

Returns
-------
torch.Tensor
Latent representation of the input image.
"""
z = self.reg(self.encoder(x))
z = self.scale_factor * (z - self.shift_factor)
return z

def decode(self, z: Tensor) -> Tensor:
"""Decode a latent representation back into an image.

Parameters
----------
z : torch.Tensor
Latent representation of shape (B, z_channels, H', W').

Returns
-------
torch.Tensor
Reconstructed image of shape (B, out_ch, H, W).
"""
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)
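
Note that decode undoes the affine normalization applied in encode; a quick check with the default factors:

# encode applies z -> scale * (z - shift); decode applies z -> z / scale + shift.
_scale, _shift = 0.3611, 0.1159        # default factors from AutoEncoderParams
_lat = torch.randn(2, 16, 32, 32)
_roundtrip = (_scale * (_lat - _shift)) / _scale + _shift
assert torch.allclose(_roundtrip, _lat, atol=1e-6)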

def forward(self, x: Tensor) -> Tensor:
"""Forward pass that encodes and decodes the input image.

Parameters
----------
x : torch.Tensor
Input image tensor.

Returns
-------
torch.Tensor
Reconstructed image.
"""
return self.decode(self.encode(x))

-def load_from_checkpoint(self, ckpt_path):
+def load_from_checkpoint(self, ckpt_path: str):
"""Load the autoencoder weights from a checkpoint file.

Parameters
----------
ckpt_path : str
Path to the checkpoint file.
"""
from safetensors.torch import load_file as load_sft

state_dict = load_sft(ckpt_path)
missing, unexpected = self.load_state_dict(state_dict)
if len(missing) > 0:
logger.warning(f"Following keys are missing from checkpoint loaded: {missing}")