Describe the bug
I'm using DeepSpeed activation checkpointing to train HunyuanVideo. The HunyuanVideo model is a stack of MMDiT blocks. Every block takes the same set of inputs and returns hidden_states and encoder_hidden_states, which become the inputs of the next block; the remaining inputs are not supposed to be modified and are passed on to the next block unchanged. Strangely, when I wrap each block with the DeepSpeed checkpoint function, some of these inputs are modified after the block's forward pass. If I switch to torch.utils.checkpoint, the error disappears.
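For reference, the chaining pattern described above looks roughly like this. This is a minimal, self-contained sketch with toy shapes and an illustrative ToyBlock, not the actual diffusers implementation (the real blocks and argument lists live in diffusers' transformer_hunyuan_video.py):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in for one transformer block: it updates the two hidden-state
# tensors, only reads temb, and is expected to leave temb itself untouched.
class ToyBlock(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states, temb):
        shift = temb.mean()
        return self.proj(hidden_states) + shift, self.proj(encoder_hidden_states) + shift

blocks = nn.ModuleList(ToyBlock() for _ in range(4))
hidden_states = torch.randn(1, 8, 16, requires_grad=True)
encoder_hidden_states = torch.randn(1, 4, 16)
temb = torch.tensor([777.0])  # 1-D tensor, reused unchanged by every block

for block in blocks:
    ref = temb.detach().clone()
    hidden_states, encoder_hidden_states = checkpoint(
        block, hidden_states, encoder_hidden_states, temb, use_reentrant=False
    )
    # With torch.utils.checkpoint this assertion always holds; replacing the call
    # with deepspeed.runtime.activation_checkpointing.checkpointing.checkpoint is
    # what reportedly breaks it (temb comes back with a different shape).
    assert temb.shape == ref.shape, f"temb changed from {ref.shape} to {temb.shape}"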
To Reproduce
Here is a simple script:
import torch
import torch.nn as nn
from typing import Optional, Tuple
from deepspeed.runtime.activation_checkpointing import checkpointing
from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoSingleTransformerBlock, HunyuanVideoTransformerBlock
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from functools import partial
from accelerate import Accelerator
from peft import LoraConfig


def hack_transformer(transformer):
    with torch.no_grad():
        in_chans = 32
        embed_dim = 3072
        patch_size = (1, 2, 2)
        new_conv_in = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        torch.nn.init.kaiming_normal_(new_conv_in.weight)
        new_conv_in.weight.data = new_conv_in.weight.data * 0.
        new_conv_in.weight.data[:, :16] = transformer.x_embedder.proj.weight.data
        new_conv_in.bias.data[:] = transformer.x_embedder.proj.bias.data
        transformer.x_embedder.proj = new_conv_in
        transformer.config.in_channels = in_chans
    return transformer


def apply_selective_checkpointing(model, block_types, p, use_deepspeed_ac):
    '''
    block_types: a list of nn.Module types to be checkpointed
    p: the fraction of all blocks to be checkpointed
    '''
    block_idx = 0
    cut_off = 1 / 2
    # when passing p as a fraction number (e.g. 1/3), it will be interpreted
    # as a string in argv, thus we need eval("1/3") here for fractions.
    p = eval(p) if isinstance(p, str) else p

    def selective_checkpointing(submodule):
        nonlocal block_idx
        nonlocal cut_off
        if isinstance(submodule, block_types):
            block_idx += 1
            if block_idx * p >= cut_off:
                cut_off += 1
                return True
        return False

    def count_total_blocks(model, block_types):
        total_blocks = 0

        def count_blocks(module):
            nonlocal total_blocks
            if isinstance(module, block_types):
                total_blocks += 1

        model.apply(count_blocks)
        return total_blocks

    if use_deepspeed_ac:
        from deepspeed.runtime.activation_checkpointing import checkpointing
        total_block_num = count_total_blocks(model, block_types)
        num_checkpoints = round(p * total_block_num)
        checkpointing.configure(
            mpu_=None,
            deepspeed_config=None,
            partition_activations=False,
            contiguous_checkpointing=False,
            num_checkpoints=num_checkpoints,
            checkpoint_in_cpu=True,
            synchronize=False,
            profile=True,
        )
        checkpoint_fn = checkpointing.checkpoint
        checkpointing_wrapper = partial(checkpoint_wrapper, checkpoint_fn=checkpoint_fn)
    else:
        checkpointing_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)

    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=checkpointing_wrapper,
        check_fn=selective_checkpointing,
    )
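# Note on the selective-checkpointing arithmetic above: with p = 1 (the value
# passed further below), block_idx * p crosses the 0.5, 1.5, 2.5, ... thresholds
# on every block, so every HunyuanVideo transformer block ends up wrapped.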
accelerator = Accelerator(
    mixed_precision='fp16',
)
accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = 1

model = HunyuanVideoTransformer3DModel()
model = hack_transformer(model)
target_modules= [
'timestep_embedder.linear_1',
'timestep_embedder.linear_2',
'text_embedder.linear_1',
'text_embedder.linear_2',
'guidance_embedder.linear_1',
'guidance_embedder.linear_1',
# 'norm1',
# 'context_embedder',
'norm_out.linear',
'norm1_context.linear',
'proj_out',
'norm1.linear',
'attn.to_q',
'attn.to_k',
'attn.to_v',
"attn.to_out.0",
"attn.add_k_proj",
"attn.add_q_proj",
"attn.add_v_proj",
'attn.to_add_out',
'ff.net.0.proj',
'ff.net.2',
'ff_context.net.0.proj',
"ff_context.net.2",
'norm.linear',
'proj_mlp',
]
# now we will add new LoRA weights to the transformer layers
rank = 256
transformer_lora_config = LoraConfig(
r=rank,
lora_alpha=rank,
init_lora_weights="gaussian",
target_modules=target_modules,
)
model.add_adapter(transformer_lora_config)
for param_name, param in model.named_parameters():
    if "lora" in param_name or "dora_scale" in param_name:
        param.requires_grad = True
    if 'x_embedder' in param_name:
        param.requires_grad = True
        print('train x_embedder')

params_to_opt = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer=torch.optim.AdamW(
params_to_opt,
lr=1e-3
)
apply_selective_checkpointing(model, (HunyuanVideoSingleTransformerBlock, HunyuanVideoTransformerBlock), "1", True)
model.train()
model, optimizer = accelerator.prepare(model, optimizer)
for i in range(10):
    hidden_states = torch.rand(1, 32, 17, 128, 72).to(device=accelerator.device, dtype=torch.float16)
    encoder_hidden_states = torch.rand(1, 256, 4096).to(device=accelerator.device, dtype=torch.float16)
    encoder_attention_mask = torch.rand(1, 256).to(device=accelerator.device, dtype=torch.float16)
    pooled_projections = torch.rand(1, 768).to(device=accelerator.device, dtype=torch.float16)
    temb = torch.tensor([777]).to(device=accelerator.device)
    guidance = torch.tensor([0.7], device=accelerator.device)
    output = model(
        hidden_states=hidden_states,
        timestep=temb,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        pooled_projections=pooled_projections,
        guidance=guidance,
    )
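For comparison, passing use_deepspeed_ac=False to apply_selective_checkpointing above falls back to the torch.utils.checkpoint-based non-reentrant wrapper, which does not show the corruption described below.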
Expected behavior
The HunyuanVideo model produces large intermediate activations, which can cause OOM even on an H20. I want to use DeepSpeed activation checkpointing to reduce the VRAM cost.
I print temb before every block forward. As the printout shows, temb is a 1-D tensor before the first block forward, but after the first block forward it has become a scalar. The block does not contain any in-place operations. If I change the checkpoint function to torch.utils.checkpoint, the error is gone.
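One way to capture that printout without editing diffusers is to attach forward pre-hooks to the blocks before training. This is a hypothetical debugging helper that assumes model and the two block classes from the script above, and that temb is passed to the blocks positionally (if your diffusers version passes it as a keyword argument, register the hook with with_kwargs=True); with reentrant checkpointing the hook may fire again during the backward recompute:

# Hypothetical debugging helper: print the shape of every tensor argument each
# block receives, so a 1-D temb turning into a 0-D scalar is visible immediately.
def log_block_inputs(module, args):
    shapes = [tuple(a.shape) if torch.is_tensor(a) else type(a).__name__ for a in args]
    print(f"{module.__class__.__name__} inputs: {shapes}")

for submodule in model.modules():
    if isinstance(submodule, (HunyuanVideoSingleTransformerBlock, HunyuanVideoTransformerBlock)):
        submodule.register_forward_pre_hook(log_block_inputs)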
System info (please complete the following information):
OS: [e.g. Ubuntu 18.04]
GPU count and types [e.g. two machines with x8 A100s each]
Interconnects (if applicable) [e.g., two machines connected with 100 Gbps IB]
Python version
Any other relevant info about your setup
Launcher context
Are you launching your experiment with the deepspeed launcher, MPI, or something else?
Docker context
Are you using a specific docker image that you can share?
Additional context
Add any other context about the problem here.