
[BUG] Input variables may be changed to scalars when using activation checkpointing #6969

Open
zhangvia opened this issue Jan 23, 2025 · 0 comments
Labels: bug, training


zhangvia commented Jan 23, 2025

Describe the bug
I'm using DeepSpeed activation checkpointing to train HunyuanVideo. The HunyuanVideo model is composed of a series of MMDiT blocks, and every block takes the same set of inputs. Each block returns hidden_states and encoder_hidden_states, which become inputs to the next block; the other input variables are not modified and are passed directly to the next block. Strangely, when I wrap every block with the DeepSpeed checkpoint function, some of the inputs are modified after the block's forward pass. If I switch to torch.utils.checkpoint, the error disappears.
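The block structure described above can be sketched with toy modules. This is a minimal illustration, not HunyuanVideo code: `ToyDualBlock`, `run_blocks`, and all tensor sizes are hypothetical. Each block reads `hidden_states`, `encoder_hidden_states`, and `temb`, returns the two updated streams, and leaves `temb` untouched; every block is wrapped with `torch.utils.checkpoint.checkpoint(..., use_reentrant=False)`, the path reported to work. Afterwards `temb` should be unchanged and the gradients should match an un-checkpointed run.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyDualBlock(nn.Module):
    """Hypothetical stand-in for an MMDiT block (not the diffusers class)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.img = nn.Linear(dim, dim)
        self.txt = nn.Linear(dim, dim)
        self.mod = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states, temb):
        # temb conditions both streams but is only read, never modified
        shift = self.mod(temb).unsqueeze(1)  # (B, 1, dim), broadcast over tokens
        hidden_states = hidden_states + self.img(hidden_states) + shift
        encoder_hidden_states = encoder_hidden_states + self.txt(encoder_hidden_states) + shift
        return hidden_states, encoder_hidden_states


def run_blocks(blocks, hidden_states, encoder_hidden_states, temb, use_ckpt=True):
    # Each block's two outputs feed the next block; temb is passed along as-is
    for block in blocks:
        if use_ckpt:
            hidden_states, encoder_hidden_states = checkpoint(
                block, hidden_states, encoder_hidden_states, temb, use_reentrant=False
            )
        else:
            hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, temb)
    return hidden_states, encoder_hidden_states


torch.manual_seed(0)
dim = 8
blocks = nn.ModuleList([ToyDualBlock(dim) for _ in range(3)])
hs = torch.randn(2, 5, dim, requires_grad=True)
ehs = torch.randn(2, 4, dim)
temb = torch.randn(2, dim)
temb_before = temb.clone()

# Checkpointed run
out_h, out_e = run_blocks(blocks, hs, ehs, temb)
(out_h.sum() + out_e.sum()).backward()
grad_ckpt = hs.grad.clone()

# Reference run without checkpointing, for comparison
hs.grad = None
ref_h, ref_e = run_blocks(blocks, hs, ehs, temb, use_ckpt=False)
(ref_h.sum() + ref_e.sum()).backward()
grad_ref = hs.grad.clone()
```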

To Reproduce
Here is a simple script:

from functools import partial

import torch
import torch.nn as nn
from accelerate import Accelerator
from diffusers.models import HunyuanVideoTransformer3DModel
from diffusers.models.transformers.transformer_hunyuan_video import (
    HunyuanVideoSingleTransformerBlock,
    HunyuanVideoTransformerBlock,
)
from peft import LoraConfig
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
def hack_transformer(transformer):
    # Widen the patch-embedding conv from 16 to 32 input channels. The new
    # channels start at zero; the first 16 reuse the existing weights.
    with torch.no_grad():
        in_chans = 32
        embed_dim = 3072
        patch_size = (1, 2, 2)
        new_conv_in = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        torch.nn.init.zeros_(new_conv_in.weight)
        new_conv_in.weight.data[:, :16] = transformer.x_embedder.proj.weight.data
        new_conv_in.bias.data[:] = transformer.x_embedder.proj.bias.data
        transformer.x_embedder.proj = new_conv_in
        transformer.config.in_channels = in_chans
    return transformer

def apply_selective_checkpointing(model, block_types, p, use_deepspeed_ac):
    '''
    block_types: a tuple of nn.Module types to be checkpointed
    p: the fraction of all blocks to be checkpointed
    '''
    block_idx = 0
    cut_off = 1 / 2
    # When p is passed as a fraction (e.g. "1/3"), it arrives as a string from
    # argv; parse it with Fraction instead of eval().
    from fractions import Fraction
    p = Fraction(p) if isinstance(p, str) else p

    def selective_checkpointing(submodule):
        nonlocal block_idx
        nonlocal cut_off

        if isinstance(submodule, block_types):
            block_idx += 1
            if block_idx * p >= cut_off:
                cut_off += 1

                return True
        return False
    
    def count_total_blocks(model, block_types):
        total_blocks = 0

        def count_blocks(module):
            nonlocal total_blocks
            if isinstance(module, block_types):
                total_blocks += 1

        model.apply(count_blocks)
        return total_blocks
    
    if use_deepspeed_ac:
        # DeepSpeed's activation checkpointing (the path that shows the bug)
        from deepspeed.runtime.activation_checkpointing import checkpointing
        total_block_num = count_total_blocks(model, block_types)
        num_checkpoints = round(p * total_block_num)
        checkpointing.configure(
            mpu_=None,
            deepspeed_config=None,
            partition_activations=False,
            contiguous_checkpointing=False,
            num_checkpoints=num_checkpoints,
            checkpoint_in_cpu=True,
            synchronize=False,
            profile=True,
        )
        checkpoint_fn = checkpointing.checkpoint
        checkpointing_wrapper = partial(checkpoint_wrapper, checkpoint_fn=checkpoint_fn)
    else:
        # torch.utils.checkpoint, non-reentrant (the error does not occur here)
        checkpointing_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)

    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=checkpointing_wrapper,
        check_fn=selective_checkpointing,
    )


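# Assumes the script is launched with an accelerate config that enables DeepSpeed;
# otherwise there is no DeepSpeed plugin to configure here.
accelerator = Accelerator(mixed_precision='fp16')
accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = 1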

model = HunyuanVideoTransformer3DModel()
model = hack_transformer(model)
target_modules = [
    'timestep_embedder.linear_1',
    'timestep_embedder.linear_2',
    'text_embedder.linear_1',
    'text_embedder.linear_2',
    'guidance_embedder.linear_1',
    'guidance_embedder.linear_2',
    # 'norm1',
    # 'context_embedder',
    'norm_out.linear',
    'norm1_context.linear',
    'proj_out',
    'norm1.linear',
    'attn.to_q',
    'attn.to_k',
    'attn.to_v',
    "attn.to_out.0",
    "attn.add_k_proj",
    "attn.add_q_proj",
    "attn.add_v_proj",
    'attn.to_add_out',
    'ff.net.0.proj',
    'ff.net.2',
    'ff_context.net.0.proj',
    "ff_context.net.2",
    'norm.linear',
    'proj_mlp',
]
# Add new LoRA weights to the transformer layers
rank = 256
transformer_lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,
    init_lora_weights="gaussian",
    target_modules=target_modules,
)
model.add_adapter(transformer_lora_config)

    
for param_name, param in model.named_parameters():
    if "lora"  in param_name or "dora_scale" in param_name:
        param.requires_grad = True
    if 'x_embedder' in param_name:
        param.requires_grad = True
        print('train x_embedder')
params_to_opt = list(filter(lambda p: p.requires_grad, model.parameters()))

optimizer = torch.optim.AdamW(
    params_to_opt,
    lr=1e-3
)

apply_selective_checkpointing(model, (HunyuanVideoSingleTransformerBlock,HunyuanVideoTransformerBlock), "1", True)
model.train()
model,optimizer = accelerator.prepare(model,optimizer)


for i in range(10):
    # Latent video input: (batch, channels, frames, height, width)
    hidden_states = torch.rand(1, 32, 17, 128, 72).to(device=accelerator.device, dtype=torch.float16)
    encoder_hidden_states = torch.rand(1, 256, 4096).to(device=accelerator.device, dtype=torch.float16)
    encoder_attention_mask = torch.rand(1, 256).to(device=accelerator.device, dtype=torch.float16)
    pooled_projections = torch.rand(1, 768).to(device=accelerator.device, dtype=torch.float16)
    temb = torch.tensor([777]).to(device=accelerator.device)  # timestep
    guidance = torch.tensor([0.7], device=accelerator.device)
    output = model(
        hidden_states=hidden_states,
        timestep=temb,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        pooled_projections=pooled_projections,
        guidance=guidance,
    )
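The `check_fn` used by `apply_selective_checkpointing` in the script spreads the checkpointed blocks evenly: a running cut-off starts at 1/2 and moves up by 1 each time a block is selected. The sketch below re-implements only that rule over block indices so it can be checked in isolation. `selected_block_indices` is a hypothetical helper, and it assumes every visited module is one of `block_types`.

```python
from fractions import Fraction


def selected_block_indices(total_blocks, p):
    """Return the 1-based indices of blocks the selection rule would checkpoint."""
    # Accept "1/3"-style strings as in the script, but parse without eval()
    p = Fraction(p) if isinstance(p, str) else p
    cut_off = Fraction(1, 2)
    selected = []
    for block_idx in range(1, total_blocks + 1):
        if block_idx * p >= cut_off:
            cut_off += 1  # the next selection needs one more block's worth of p
            selected.append(block_idx)
    return selected


print(selected_block_indices(6, "1/3"))  # [2, 5]
print(selected_block_indices(4, "1"))    # [1, 2, 3, 4]
```

With `p = "1"`, as in the script's call, every block is wrapped, which agrees with `num_checkpoints = round(p * total_block_num)` in the DeepSpeed branch.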

Expected behavior
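The HunyuanVideo model produces large intermediate activations, which can cause OOM even on an H20. I want to use DeepSpeed activation checkpointing to reduce VRAM usage. Inputs that a checkpointed block does not modify (such as temb) should reach the next block unchanged, as they do with torch.utils.checkpoint.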
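For a sense of scale, here is a rough back-of-the-envelope estimate using the shapes from the reproduction script. It is a sketch under stated assumptions: the (1, 2, 2) patch embedding uses stride equal to its kernel (as in `hack_transformer`), the 256 text tokens are assumed to be concatenated with the video tokens into one joint sequence, and only a single fp16 `(tokens, 3072)` tensor is counted. A block without checkpointing typically keeps several tensors of this size (and larger ones for MLP hidden layers) for the backward pass.

```python
def tokens_after_patchify(frames, height, width, patch=(1, 2, 2)):
    """Number of video tokens after a (t, h, w) patch embedding with matching stride."""
    pt, ph, pw = patch
    return (frames // pt) * (height // ph) * (width // pw)


def tensor_mib(num_tokens, dim, bytes_per_elem=2):
    """Size in MiB of one (num_tokens, dim) tensor; 2 bytes per element for fp16."""
    return num_tokens * dim * bytes_per_elem / 2**20


# Latent shape (frames, height, width) and text length taken from the script
video_tokens = tokens_after_patchify(frames=17, height=128, width=72)
text_tokens = 256
joint_tokens = video_tokens + text_tokens
print(video_tokens, joint_tokens, tensor_mib(joint_tokens, 3072))
# 39168 39424 231.0
```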

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
 [WARNING]  FP Quantizer is using an untested triton version (3.1.0), only 2.3.(0, 1) and 3.0.0 are known to be compatible with these kernels
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
/media/sda/zjy/software/miniconda3/envs/stable-fast/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ifstream<char, std::char_traits<char> >::close()@GLIBCXX_3.4'
/media/sda/zjy/software/miniconda3/envs/stable-fast/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_guard_release@CXXABI_1.3'
/media/sda/zjy/software/miniconda3/envs/stable-fast/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_throw@CXXABI_1.3'
/media/sda/zjy/software/miniconda3/envs/stable-fast/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_decrement(std::_Rb_tree_node_base*)@GLIBCXX_3.4'
/media/sda/zjy/software/miniconda3/envs/stable-fast/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_filebuf<char, std::char_traits<char> >::~basic_filebuf()@GLIBCXX_3.4'
collect2: error: ld returned 1 exit status
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.5
 [WARNING]  using untested triton version (3.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/media/sda/zjy/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch']
torch version .................... 2.5.1+cu118
deepspeed install path ........... ['/media/sda/zjy/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.16.3, unknown, unknown
torch cuda version ............... 11.8
torch hip version ................ None
nvcc version ..................... 11.8
deepspeed wheel compiled w. ...... torch 2.5, cuda 11.8
shared memory (/dev/shm) size .... 503.76 GB

Screenshots


I print temb before every block's forward pass. As the screenshot shows, temb is a 1-D tensor before the first block's forward pass, but after the first block it has become a scalar. The block does not perform any in-place operations. If I switch the checkpoint function to torch.utils.checkpoint, the error goes away.
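The print-based check described above can be automated. This is a minimal sketch, assuming you can call the checkpoint function yourself around a block; `call_and_check`, `tensor_meta`, `pure_block`, and `rebinding_block` are hypothetical helpers, not DeepSpeed or PyTorch APIs. It records each tensor argument's shape, dtype, and device before the call and raises if any of them differ afterwards. `rebinding_block` only illustrates that a caller's tensor can change shape without any in-place arithmetic (by reassigning `.data`); it is not a claim about what DeepSpeed does internally.

```python
import torch


def tensor_meta(args):
    """(position, shape, dtype, device) for every tensor in args."""
    return [(i, tuple(a.shape), a.dtype, a.device) for i, a in enumerate(args) if torch.is_tensor(a)]


def call_and_check(fn, *args):
    """Call fn(*args) and raise if any tensor argument's metadata changed."""
    before = tensor_meta(args)
    out = fn(*args)
    for b, a in zip(before, tensor_meta(args)):
        if b != a:
            raise RuntimeError(f"argument {b[0]} changed from {b[1:]} to {a[1:]}")
    return out


def pure_block(x, temb):
    # Returns new tensors; the caller's inputs are left alone
    return x * 2 + temb.sum()


def rebinding_block(x, temb):
    # Hypothetical misbehaving function: no in-place arithmetic, but swapping
    # .data still changes the caller's tensor, here into a 0-d (scalar) tensor
    temb.data = torch.empty([], device=temb.device)
    return x * 2


x = torch.ones(4)
temb = torch.ones(3)
call_and_check(pure_block, x, temb)  # passes: nothing was mutated
try:
    call_and_check(rebinding_block, x, temb)
except RuntimeError as err:
    print(err)
```

In the reproduction script, one could, for example, check a single block with `call_and_check(checkpointing.checkpoint, block, hidden_states, encoder_hidden_states, temb, ...)`; the block module itself is not a tensor, so only the tensor arguments are compared.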


zhangvia added the bug and training labels on Jan 23, 2025.