From abb0143491f3aa45e395a64f8ed5383b3a73e5cb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Mon, 21 Oct 2024 18:50:31 +0800
Subject: [PATCH 01/21] weight parallel

---
 ..._to_image_generation-stable_diffusion_3.py | 22 ++++++----
 .../ppdiffusers/models/simplified_sd3.py      | 11 +++--
 .../ppdiffusers/models/transformer_sd3.py     | 27 ++++++++++++
 .../pipeline_stable_diffusion_3.py            | 44 +++++++++----------
 4 files changed, 70 insertions(+), 34 deletions(-)

diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
index fd1784f8f..cb7a8017b 100644
--- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
+++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
@@ -48,7 +48,7 @@ def parse_args():
 
 if args.inference_optimize:
     os.environ["INFERENCE_OPTIMIZE"] = "True"
-    os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True"
+    # os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True"
 if args.inference_optimize_bp:
     os.environ["INFERENCE_OPTIMIZE_BP"] = "True"
 if args.dtype == "float32":
@@ -76,6 +76,10 @@ def parse_args():
     hcg = fleet.get_hybrid_communicate_group()
     mp_id = hcg.get_model_parallel_rank()
     rank_id = dist.get_rank()
+    if rank_id==0:
+        os.environ["TRITON_KERNEL_CACHE_DIR"]="./tmp/sd3_parallel/2_2"
+    elif rank_id==1:
+        os.environ["TRITON_KERNEL_CACHE_DIR"]="./tmp/sd3_parallel/2_3"
 
 import datetime
 from ppdiffusers import StableDiffusion3Pipeline
@@ -86,14 +90,14 @@ def parse_args():
     paddle_dtype=inference_dtype,
 )
 
-pipe.transformer = paddle.incubate.jit.inference(
-    pipe.transformer,
-    save_model_dir="./tmp/sd3",
-    enable_new_ir=True,
-    cache_static_model=True,
-    exp_enable_use_cutlass=True,
-    delete_pass_lists=["add_norm_fuse_pass"],
-)
+# pipe.transformer = paddle.incubate.jit.inference(
+#     pipe.transformer,
+#     save_model_dir="./tmp/sd3_parallel",
+#     enable_new_ir=True,
+#     cache_static_model=False,
+#     exp_enable_use_cutlass=True,
+#     delete_pass_lists=["add_norm_fuse_pass"],
+# )
 
 generator = paddle.Generator().manual_seed(42)
 prompt = "A cat holding a sign that says hello world"

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index 2d3ace335..4bc02cdd2 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -15,7 +15,8 @@
 import paddle
 import paddle.nn.functional as F
 from paddle import nn
-
+import paddle.distributed as dist
+import paddle.distributed.fleet as fleet
 
 class SimplifiedSD3(nn.Layer):
     def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
@@ -35,6 +36,10 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio
         self.eqkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
         self.to_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
         self.to_add_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers - 1)])
+
+        self.ffn = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+        self.ffnc = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+
         self.ffn1 = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers)])
         self.ffn2 = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers)])
         self.ffn1_context = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers - 1)])
@@ -131,9 +136,9 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             )
 
             # ffn1
-            ffn_output = self.ffn1[i](norm_hidden_states)
+            ffn_output = self.ffn[i](norm_hidden_states)
             ffn_output = F.gelu(ffn_output, approximate=True)
-            ffn_output = self.ffn2[i](ffn_output)
+            ffn_output = self.ffnc[i](ffn_output)
 
             if context_pre_only:
                 ffn_output = gate_mlp.unsqueeze(1) * ffn_output

diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
index 8188ab47c..1d3263faf 100644
--- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
@@ -425,3 +425,30 @@ def custom_modify_weight(cls, state_dict):
                 ],
                 axis=-1,
             )
+            from paddle.distributed.fleet.utils import recompute
+            import paddle.distributed as dist
+            import paddle.distributed.fleet as fleet
+            a,b = paddle.split(
+                state_dict[f"simplified_sd3.ffn1.{i}.weight"],
+                2,
+                axis=-1,
+            )
+            c,d = paddle.split(
+                state_dict[f"simplified_sd3.ffn1.{i}.bias"],
+                2,
+                axis=-1,
+            )
+            a1,b1 = paddle.split(
+                state_dict[f"simplified_sd3.ffn2.{i}.weight"],
+                2,
+                axis=0,
+            )
+            rank_id = dist.get_rank()
+            if rank_id==0:
+                state_dict[f"simplified_sd3.ffn.{i}.weight"] = a
+                state_dict[f"simplified_sd3.ffn.{i}.bias"] = c
+                state_dict[f"simplified_sd3.ffnc.{i}.weight"] = a1
+            elif rank_id==1:
+                state_dict[f"simplified_sd3.ffn.{i}.weight"] = b
+                state_dict[f"simplified_sd3.ffn.{i}.bias"] = d
+                state_dict[f"simplified_sd3.ffnc.{i}.weight"] = b1
\ No newline at end of file

diff --git a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 85069e42f..be89fc33a 100644
--- a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -800,22 +800,22 @@ def __call__(
                 latent_model_input = paddle.concat([latents] * 2) if self.do_classifier_free_guidance else latents
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
-                if self.inference_optimize_bp and self.do_classifier_free_guidance:
-                    latent_input ,latent_model_input_ = paddle.split(latent_model_input,2,axis=0)
-                    timestep_input ,timestep_ = paddle.split(timestep,2,axis=0)
-                    prompt_embeds_input ,prompt_embeds_ = paddle.split(prompt_embeds,2,axis=0)
-                    pooled_prompt_embeds_input ,pooled_prompt_embeds_ = paddle.split(pooled_prompt_embeds,2,axis=0)
+                # if self.inference_optimize_bp and self.do_classifier_free_guidance:
+                #     latent_input ,latent_model_input_ = paddle.split(latent_model_input,2,axis=0)
+                #     timestep_input ,timestep_ = paddle.split(timestep,2,axis=0)
+                #     prompt_embeds_input ,prompt_embeds_ = paddle.split(prompt_embeds,2,axis=0)
+                #     pooled_prompt_embeds_input ,pooled_prompt_embeds_ = paddle.split(pooled_prompt_embeds,2,axis=0)
 
-                    dist.scatter(latent_input,[latent_input,latent_model_input_])
-                    dist.scatter(timestep_input,[timestep_input,timestep_])
-                    dist.scatter(prompt_embeds_input,[prompt_embeds_input,prompt_embeds_])
-                    dist.scatter(pooled_prompt_embeds_input,[pooled_prompt_embeds_input,pooled_prompt_embeds_])
-
-                else:
-                    latent_input = latent_model_input
-                    timestep_input = timestep
-                    prompt_embeds_input = prompt_embeds
-                    pooled_prompt_embeds_input = pooled_prompt_embeds
+                # dist.scatter(latent_input,[latent_input,latent_model_input_])
+                # dist.scatter(timestep_input,[timestep_input,timestep_])
+                # dist.scatter(prompt_embeds_input,[prompt_embeds_input,prompt_embeds_])
+                # dist.scatter(pooled_prompt_embeds_input,[pooled_prompt_embeds_input,pooled_prompt_embeds_])
+
+                # else:
+                latent_input = latent_model_input
+                timestep_input = timestep
+                prompt_embeds_input = prompt_embeds
+                pooled_prompt_embeds_input = pooled_prompt_embeds
 
                 model_output = self.transformer(
                     hidden_states=latent_input,
@@ -832,13 +832,13 @@ def __call__(
                 else:
                     output = model_output[0]
 
-                if self.inference_optimize_bp:
-                    tmp_shape = output.shape
-                    tmp_shape[0] *=2
-                    noise_pred = paddle.zeros(tmp_shape,dtype=output.dtype)
-                    dist.all_gather(noise_pred,output)
-                else:
-                    noise_pred = output
+                # if self.inference_optimize_bp:
+                #     tmp_shape = output.shape
+                #     tmp_shape[0] *=2
+                #     noise_pred = paddle.zeros(tmp_shape,dtype=output.dtype)
+                #     dist.all_gather(noise_pred,output)
+                # else:
+                noise_pred = output
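Patch 01 turns the pipeline's batch-parallel (BP) branch into comments while the weight-parallel FFN is brought up, so it helps to keep the disabled scheme in mind. Below is a minimal standalone sketch of that round trip, not part of the patch itself: two ranks are assumed, a multiply stands in for the transformer forward, and the collective calls mirror the ones in the commented-out code.

import paddle
import paddle.distributed as dist

dist.init_parallel_env()  # e.g. launched on two GPUs with paddle.distributed.launch

# CFG doubles the batch: row 0 is the unconditional half, row 1 the conditional half.
latent_model_input = paddle.randn([2, 16, 128, 128])

# Split the CFG batch and scatter it, mirroring the commented-out branch:
# after the call, rank 0 holds the first half and rank 1 the second.
latent_input, latent_input_ = paddle.split(latent_model_input, 2, axis=0)
dist.scatter(latent_input, [latent_input, latent_input_])

output = latent_input * 0.5  # stand-in for the transformer forward on the local half

# Re-assemble the full [uncond; cond] prediction on every rank, using the same
# preallocated-tensor all_gather call pattern as the diff.
tmp_shape = output.shape
tmp_shape[0] *= 2
noise_pred = paddle.zeros(tmp_shape, dtype=output.dtype)
dist.all_gather(noise_pred, output)
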
From c92fd3cb86ff84cb00d2d5a5d833665d52e3675c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Mon, 21 Oct 2024 19:39:03 +0800
Subject: [PATCH 02/21] weight parallel

---
 ...t_to_image_generation-stable_diffusion_3.py | 18 +++++++++---------
 .../ppdiffusers/models/simplified_sd3.py       |  1 +
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
index cb7a8017b..e298ce6a9 100644
--- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
+++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
@@ -48,7 +48,7 @@ def parse_args():
 
 if args.inference_optimize:
     os.environ["INFERENCE_OPTIMIZE"] = "True"
-    # os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True"
+    os.environ["INFERENCE_OPTIMIZE_TRITON"] = "False"
 if args.inference_optimize_bp:
     os.environ["INFERENCE_OPTIMIZE_BP"] = "True"
 if args.dtype == "float32":
@@ -90,14 +90,14 @@ def parse_args():
     paddle_dtype=inference_dtype,
 )
 
-# pipe.transformer = paddle.incubate.jit.inference(
-#     pipe.transformer,
-#     save_model_dir="./tmp/sd3_parallel",
-#     enable_new_ir=True,
-#     cache_static_model=False,
-#     exp_enable_use_cutlass=True,
-#     delete_pass_lists=["add_norm_fuse_pass"],
-# )
+pipe.transformer = paddle.incubate.jit.inference(
+    pipe.transformer,
+    save_model_dir="./tmp/sd3_parallel",
+    enable_new_ir=False,
+    cache_static_model=False,
+    exp_enable_use_cutlass=True,
+    delete_pass_lists=["add_norm_fuse_pass"],
+)
 
 generator = paddle.Generator().manual_seed(42)
 prompt = "A cat holding a sign that says hello world"

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index 4bc02cdd2..f06ea22a6 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -15,6 +15,7 @@
 import paddle
 import paddle.nn.functional as F
 from paddle import nn
+
 import paddle.distributed as dist
 import paddle.distributed.fleet as fleet
 

From 9a005045d6c1270d9aeda2d2f1512071b62e2f8b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Mon, 21 Oct 2024 20:52:43 +0800
Subject: [PATCH 03/21] weight parallel

---
 ..._to_image_generation-stable_diffusion_3.py |  4 +--
 .../ppdiffusers/models/simplified_sd3.py      | 18 ++++++++---
 .../ppdiffusers/models/transformer_sd3.py     | 31 +++++++------------
 3 files changed, 26 insertions(+), 27 deletions(-)

diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
index e298ce6a9..cb9e9de69 100644
--- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
+++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
@@ -48,7 +48,7 @@ def parse_args():
 
 if args.inference_optimize:
     os.environ["INFERENCE_OPTIMIZE"] = "True"
-    os.environ["INFERENCE_OPTIMIZE_TRITON"] = "False"
+    # os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True"
 if args.inference_optimize_bp:
     os.environ["INFERENCE_OPTIMIZE_BP"] = "True"
 if args.dtype == "float32":
@@ -94,7 +94,7 @@ def parse_args():
     pipe.transformer,
     save_model_dir="./tmp/sd3_parallel",
     enable_new_ir=False,
-    cache_static_model=False,
+    cache_static_model=True,
     exp_enable_use_cutlass=True,
     delete_pass_lists=["add_norm_fuse_pass"],
 )

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index f06ea22a6..78f2a5d58 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -19,6 +19,8 @@
 import paddle.distributed as dist
 import paddle.distributed.fleet as fleet
 
+model_parallel_size=2
+
 class SimplifiedSD3(nn.Layer):
     def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
         super().__init__()
@@ -38,8 +40,9 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio
         self.to_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
         self.to_add_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers - 1)])
 
-        self.ffn = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-        self.ffnc = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+        if model_parallel_size > 1:
+            self.ffn = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.ffnc = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
 
         self.ffn1 = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers)])
         self.ffn2 = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers)])
@@ -137,9 +140,14 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             )
 
             # ffn1
-            ffn_output = self.ffn[i](norm_hidden_states)
-            ffn_output = F.gelu(ffn_output, approximate=True)
-            ffn_output = self.ffnc[i](ffn_output)
+            if model_parallel_size > 1:
+                ffn_output = self.ffn[i](norm_hidden_states)
+                ffn_output = F.gelu(ffn_output, approximate=True)
+                ffn_output = self.ffnc[i](ffn_output)
+            else:
+                ffn_output = self.ffn1[i](norm_hidden_states)
+                ffn_output = F.gelu(ffn_output, approximate=True)
+                ffn_output = self.ffn2[i](ffn_output)
 
             if context_pre_only:
                 ffn_output = gate_mlp.unsqueeze(1) * ffn_output

diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
index 1d3263faf..f1ffae72a 100644
--- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
@@ -428,27 +428,18 @@ def custom_modify_weight(cls, state_dict):
             from paddle.distributed.fleet.utils import recompute
             import paddle.distributed as dist
             import paddle.distributed.fleet as fleet
-            a,b = paddle.split(
-                state_dict[f"simplified_sd3.ffn1.{i}.weight"],
-                2,
-                axis=-1,
-            )
-            c,d = paddle.split(
-                state_dict[f"simplified_sd3.ffn1.{i}.bias"],
-                2,
-                axis=-1,
-            )
-            a1,b1 = paddle.split(
+            rank_id = dist.get_rank()
+            tmpc = paddle.split(
                 state_dict[f"simplified_sd3.ffn2.{i}.weight"],
                 2,
                 axis=0,
             )
-            rank_id = dist.get_rank()
-            if rank_id==0:
-                state_dict[f"simplified_sd3.ffn.{i}.weight"] = a
-                state_dict[f"simplified_sd3.ffn.{i}.bias"] = c
-                state_dict[f"simplified_sd3.ffnc.{i}.weight"] = a1
-            elif rank_id==1:
-                state_dict[f"simplified_sd3.ffn.{i}.weight"] = b
-                state_dict[f"simplified_sd3.ffn.{i}.bias"] = d
-                state_dict[f"simplified_sd3.ffnc.{i}.weight"] = b1
\ No newline at end of file
+            state_dict[f"simplified_sd3.ffnc.{i}.weight"] = tmpc[rank_id]
+            state_dict[f"simplified_sd3.ffnc.{i}.bias"] = state_dict[f"simplified_sd3.ffn2.{i}.bias"]
+            for placeholder in ["weight", "bias"]:
+                tmp = paddle.split(
+                    state_dict[f"simplified_sd3.ffn1.{i}.{placeholder}"],
+                    2,
+                    axis=-1,
+                )
+                state_dict[f"simplified_sd3.ffn.{i}.{placeholder}"] = tmp[rank_id]

From 72e78f8767fa6b8122e2b3582b5d391acd6c3ebd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Thu, 24 Oct 2024 13:10:35 +0800
Subject: [PATCH 04/21] model parallel ffn rename

---
 ppdiffusers/ppdiffusers/models/simplified_sd3.py  | 8 ++++----
 ppdiffusers/ppdiffusers/models/transformer_sd3.py | 6 +++---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index 78f2a5d58..3942aee16 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -41,8 +41,8 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio
         self.to_add_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers - 1)])
 
         if model_parallel_size > 1:
-            self.ffn = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.ffnc = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.ffn1_mp = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.ffn2_mp = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
 
         self.ffn1 = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers)])
         self.ffn2 = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers)])
@@ -141,9 +141,9 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
 
             # ffn1
             if model_parallel_size > 1:
-                ffn_output = self.ffn[i](norm_hidden_states)
+                ffn_output = self.ffn1_mp[i](norm_hidden_states)
                 ffn_output = F.gelu(ffn_output, approximate=True)
-                ffn_output = self.ffnc[i](ffn_output)
+                ffn_output = self.ffn2_mp[i](ffn_output)
             else:
                 ffn_output = self.ffn1[i](norm_hidden_states)
                 ffn_output = F.gelu(ffn_output, approximate=True)

diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
index f1ffae72a..d11add9b8 100644
--- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
@@ -434,12 +434,12 @@ def custom_modify_weight(cls, state_dict):
                 2,
                 axis=0,
             )
-            state_dict[f"simplified_sd3.ffnc.{i}.weight"] = tmpc[rank_id]
-            state_dict[f"simplified_sd3.ffnc.{i}.bias"] = state_dict[f"simplified_sd3.ffn2.{i}.bias"]
+            state_dict[f"simplified_sd3.ffn2_mp.{i}.weight"] = tmpc[rank_id]
+            state_dict[f"simplified_sd3.ffn2_mp.{i}.bias"] = state_dict[f"simplified_sd3.ffn2.{i}.bias"]
             for placeholder in ["weight", "bias"]:
                 tmp = paddle.split(
                     state_dict[f"simplified_sd3.ffn1.{i}.{placeholder}"],
                     2,
                     axis=-1,
                 )
-                state_dict[f"simplified_sd3.ffn.{i}.{placeholder}"] = tmp[rank_id]
+                state_dict[f"simplified_sd3.ffn1_mp.{i}.{placeholder}"] = tmp[rank_id]
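The split axes used in custom_modify_weight above follow the standard Megatron convention. Paddle stores Linear weights as [in_features, out_features], so a ColumnParallelLinear (ffn1_mp) shards weight and bias along the output axis (axis=-1), while a RowParallelLinear (ffn2_mp) shards the weight along the input axis (axis=0) and keeps the bias whole, because the bias is added once after the partial products are all-reduced. A minimal sketch, with hypothetical helper names and SD3-medium sizes (dim 1536) assumed:

import paddle

def shard_column_parallel(weight, bias, mp_degree, mp_id):
    # output dimension is sharded: [in, out] -> [in, out // mp_degree]
    w = paddle.split(weight, mp_degree, axis=-1)[mp_id]
    b = paddle.split(bias, mp_degree, axis=-1)[mp_id]
    return w, b

def shard_row_parallel(weight, bias, mp_degree, mp_id):
    # input dimension is sharded: [in, out] -> [in // mp_degree, out];
    # the bias stays whole and is applied after the mp all-reduce
    return paddle.split(weight, mp_degree, axis=0)[mp_id], bias

dim = 1536
ffn1_w, ffn1_b = shard_column_parallel(paddle.randn([dim, 4 * dim]), paddle.randn([4 * dim]), 2, 0)
ffn2_w, ffn2_b = shard_row_parallel(paddle.randn([4 * dim, dim]), paddle.randn([dim]), 2, 0)
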
From e2e7a9707d74979670129fc626fe77d03a3d0fee Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Sat, 26 Oct 2024 23:06:50 +0800
Subject: [PATCH 05/21] add all Model Parallel

---
 ..._to_image_generation-stable_diffusion_3.py |  4 +-
 .../ppdiffusers/models/simplified_sd3.py      | 72 +++++++++++++------
 .../ppdiffusers/models/transformer_sd3.py     | 34 +++++----
 3 files changed, 74 insertions(+), 36 deletions(-)

diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
index cb9e9de69..ae78d523d 100644
--- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
+++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
@@ -92,9 +92,9 @@ def parse_args():
 
 pipe.transformer = paddle.incubate.jit.inference(
     pipe.transformer,
-    save_model_dir="./tmp/sd3_parallel",
+    save_model_dir="./tmp/TP_sd3_parallel",
     enable_new_ir=False,
-    cache_static_model=True,
+    cache_static_model=False,
     exp_enable_use_cutlass=True,
     delete_pass_lists=["add_norm_fuse_pass"],
 )

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index 3942aee16..87bc12a2a 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -35,19 +35,29 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio
 
         self.norm_last_context = nn.LayerNorm(self.dim, epsilon=1e-6, weight_attr=False, bias_attr=True)
 
-        self.qkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
-        self.eqkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
-        self.to_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
-        self.to_add_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers - 1)])
+        if model_parallel_size > 1:
+            self.qkv_mp = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.eqkv_mp = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.to_out_linear_mp = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.to_add_out_linear_mp = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
 
             self.ffn1_mp = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
             self.ffn2_mp = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.ffn1_context_mp = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.ffn2_context_mp = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+        else:
+            self.qkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
+            self.eqkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
+            self.to_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
+            self.to_add_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
+
+            self.ffn1 = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers)])
+            self.ffn2 = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers)])
+            self.ffn1_context = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers - 1)])
+            self.ffn2_context = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers - 1)])
 
-        self.ffn1 = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers)])
-        self.ffn2 = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers)])
-        self.ffn1_context = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers - 1)])
-        self.ffn2_context = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers - 1)])
+
     def forward(self, hidden_states, encoder_hidden_states, temb):
         print("--------------------this is simplified_sd3------------------------")
@@ -112,27 +122,40 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
                 epsilon=1e-06,
             )
 
-            qkv = self.qkv[i](norm_hidden_states)
-            eqkv = self.eqkv[i](norm_encoder_hidden_states)
+            if model_parallel_size > 1:
+                qkv = self.qkv_mp[i](norm_hidden_states)
+                eqkv = self.eqkv_mp[i](norm_encoder_hidden_states)
+
+            else:
+                qkv = self.qkv[i](norm_hidden_states)
+                eqkv = self.eqkv[i](norm_encoder_hidden_states)
+
+
             q, k, v = paddlemix.triton_ops.split_concat(qkv, eqkv)
             bs = hidden_states.shape[0]
-            q = q.reshape([bs, -1, 24, 64])
-            k = k.reshape([bs, -1, 24, 64])
-            v = v.reshape([bs, -1, 24, 64])
+            hs = q.shape[2]
+            if model_parallel_size > 1:
+                q = q.reshape([bs, -1, 12, hs//12])
+                k = k.reshape([bs, -1, 12, hs//12])
+                v = v.reshape([bs, -1, 12, hs//12])
+            else:
+                q = q.reshape([bs, -1, 24, 64])
+                k = k.reshape([bs, -1, 24, 64])
+                v = v.reshape([bs, -1, 24, 64])
 
             norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False)
-            norm_hidden_states1 = norm_hidden_states1.reshape([bs, -1, self.dim])
+            norm_hidden_states1 = norm_hidden_states1.reshape([bs, -1, hs])
             attn_output, context_attn_output = paddle.split(norm_hidden_states1, num_or_sections=[seq1, seq2], axis=1)
             # attn_output, context_attn_output = paddlemix.triton_ops.triton_split(
             #     norm_hidden_states1, num_or_sections=[1024, 154], axis=1
             # )
 
-            attn_output = paddle.nn.functional.linear(
-                attn_output, self.to_out_linear[i].weight, self.to_out_linear[i].bias
-            )
-
-            if not context_pre_only:
+            if model_parallel_size > 1:
+                attn_output = self.to_out_linear_mp[i](attn_output)
+                context_attn_output = self.to_add_out_linear_mp[i](context_attn_output)
+            else:
+                attn_output = self.to_out_linear[i](attn_output)
                 context_attn_output = self.to_add_out_linear[i](context_attn_output)
 
             hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
@@ -163,9 +186,14 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
                     encoder_hidden_states, context_attn_output, c_gate_msa, c_scale_mlp, c_shift_mlp, epsilon=1e-06
                 )
 
-                context_ffn_output = self.ffn1_context[i](norm_encoder_hidden_states)
-                context_ffn_output = F.gelu(context_ffn_output, approximate=True)
-                context_ffn_output = self.ffn2_context[i](context_ffn_output)
+                if model_parallel_size > 1:
+                    context_ffn_output = self.ffn1_context_mp[i](norm_encoder_hidden_states)
+                    context_ffn_output = F.gelu(context_ffn_output, approximate=True)
+                    context_ffn_output = self.ffn2_context_mp[i](context_ffn_output)
+                else:
+                    context_ffn_output = self.ffn1_context[i](norm_encoder_hidden_states)
+                    context_ffn_output = F.gelu(context_ffn_output, approximate=True)
+                    context_ffn_output = self.ffn2_context[i](context_ffn_output)
 
                 last_context_ffn_output = context_ffn_output
                 last_context_hidden_states = encoder_hidden_states

diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
index d11add9b8..bae4b7dae 100644
--- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
@@ -429,17 +429,27 @@ def custom_modify_weight(cls, state_dict):
             import paddle.distributed as dist
             import paddle.distributed.fleet as fleet
             rank_id = dist.get_rank()
-            tmpc = paddle.split(
-                state_dict[f"simplified_sd3.ffn2.{i}.weight"],
-                2,
-                axis=0,
-            )
-            state_dict[f"simplified_sd3.ffn2_mp.{i}.weight"] = tmpc[rank_id]
-            state_dict[f"simplified_sd3.ffn2_mp.{i}.bias"] = state_dict[f"simplified_sd3.ffn2.{i}.bias"]
+            for mp_name in ["ffn2","to_out_linear"]:
+                tmpc = paddle.split(state_dict[f"simplified_sd3.{mp_name}.{i}.weight"],2,axis=0)
+                state_dict[f"simplified_sd3.{mp_name}_mp.{i}.weight"] = tmpc[rank_id]
+                state_dict[f"simplified_sd3.{mp_name}_mp.{i}.bias"] = state_dict[f"simplified_sd3.{mp_name}.{i}.bias"]
+            if i<23:
+                tmpc = paddle.split(state_dict[f"simplified_sd3.to_add_out_linear.{i}.weight"],2,axis=0)
+                state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.weight"] = tmpc[rank_id]
+                state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.bias"] = state_dict[f"simplified_sd3.to_add_out_linear.{i}.bias"]
+                tmpf = paddle.split(state_dict[f"simplified_sd3.ffn2_context.{i}.weight"],2,axis=0)
+                state_dict[f"simplified_sd3.ffn2_context_mp.{i}.weight"] = tmpf[rank_id]
+                state_dict[f"simplified_sd3.ffn2_context_mp.{i}.bias"] = state_dict[f"simplified_sd3.ffn2_context.{i}.bias"]
+                for placeholder in ["weight", "bias"]:
+                    tmpf1 = paddle.split(state_dict[f"simplified_sd3.ffn1_context.{i}.{placeholder}"],2,axis=-1)
+                    state_dict[f"simplified_sd3.ffn1_context_mp.{i}.{placeholder}"] = tmpf1[rank_id]
             for placeholder in ["weight", "bias"]:
-                tmp = paddle.split(
-                    state_dict[f"simplified_sd3.ffn1.{i}.{placeholder}"],
-                    2,
-                    axis=-1,
-                )
+                tmp = paddle.split(state_dict[f"simplified_sd3.ffn1.{i}.{placeholder}"],2,axis=-1)
                 state_dict[f"simplified_sd3.ffn1_mp.{i}.{placeholder}"] = tmp[rank_id]
+                for placeholder1 in ["", "e"]:
+                    tmpq = paddle.split(state_dict[f"simplified_sd3.{placeholder1}qkv.{i}.{placeholder}"],6,axis=-1)
+                    tmpqp = [
+                        paddle.concat([tmpq[0], tmpq[2],tmpq[4]],axis=-1),
+                        paddle.concat([tmpq[1], tmpq[3],tmpq[5]],axis=-1)
+                    ]
+                    state_dict[f"simplified_sd3.{placeholder1}qkv_mp.{i}.{placeholder}"] = tmpqp[rank_id]
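The qkv loop at the end of patch 05 handles a subtlety of fused projections: simply halving the fused [q | k | v] weight along its last axis would hand rank 0 all of q plus half of k, so the weight is first split into 3 * mp_degree chunks and then regrouped so that each rank receives matching q, k and v shards. Isolated, the regrouping looks like the sketch below (illustrative sizes; all names are local to the sketch):

import paddle

mp_degree, mp_id = 2, 0
dim = 1536
qkv_weight = paddle.randn([dim, 3 * dim])  # fused [q | k | v] projection

# chunks are ordered [q0, q1, k0, k1, v0, v1] when mp_degree == 2
chunks = paddle.split(qkv_weight, 3 * mp_degree, axis=-1)

# rank r takes chunk r of q, of k and of v, concatenated back into a local fused weight
local_qkv = paddle.concat(
    [chunks[mp_id], chunks[mp_id + mp_degree], chunks[mp_id + 2 * mp_degree]], axis=-1
)  # shape [dim, 3 * dim // mp_degree]
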
From 79eb8382619fa0949f6ab2bcb0ef88e2c13ffef2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Sun, 27 Oct 2024 21:32:00 +0800
Subject: [PATCH 06/21] update MP

---
 .../ppdiffusers/models/simplified_sd3.py | 28 +++++++++----------
 1 file changed, 13 insertions(+), 15 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index 87bc12a2a..009985804 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -18,7 +18,8 @@
 
 import paddle.distributed as dist
 import paddle.distributed.fleet as fleet
-
+import paddle.distributed.fleet.meta_parallel.ColumnParallelLinear as CPLinear
+import paddle.distributed.fleet.meta_parallel.RowParallelLinear as RPLinear
 model_parallel_size=2
 
 class SimplifiedSD3(nn.Layer):
@@ -32,21 +33,18 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio
         self.linear_context = nn.LayerList(
             [nn.Linear(self.dim, (6 if i < num_layers - 1 else 2) * self.dim) for i in range(num_layers)]
         )
-
         self.norm_last_context = nn.LayerNorm(self.dim, epsilon=1e-6, weight_attr=False, bias_attr=True)
 
         if model_parallel_size > 1:
-            self.qkv_mp = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.eqkv_mp = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.to_out_linear_mp = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
-            self.to_add_out_linear_mp = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.qkv_mp = nn.LayerList([CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.eqkv_mp = nn.LayerList([CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.to_out_linear_mp = nn.LayerList([RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.to_add_out_linear_mp = nn.LayerList([RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
 
-            self.ffn1_mp = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.ffn2_mp = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
-            self.ffn1_context_mp = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.ffn2_context_mp = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.ffn1_mp = nn.LayerList([CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.ffn2_mp = nn.LayerList([RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.ffn1_context_mp = nn.LayerList([CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.ffn2_context_mp = nn.LayerList([RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
         else:
             self.qkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
@@ -133,9 +134,9 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             bs = hidden_states.shape[0]
             hs = q.shape[2]
             if model_parallel_size > 1:
-                q = q.reshape([bs, -1, 12, hs//12])
-                k = k.reshape([bs, -1, 12, hs//12])
-                v = v.reshape([bs, -1, 12, hs//12])
+                q = q.reshape([bs, -1, hs//64, 64])
+                k = k.reshape([bs, -1, hs//64, 64])
+                v = v.reshape([bs, -1, hs//64, 64])
             else:
                 q = q.reshape([bs, -1, 24, 64])
                 k = k.reshape([bs, -1, 24, 64])

From 283886cdfeb98f7593faf8738e251f1dc724b047 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Sun, 27 Oct 2024 21:43:23 +0800
Subject: [PATCH 07/21] update MP

---
 .../ppdiffusers/models/simplified_sd3.py | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index 009985804..9b207f7bc 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -18,8 +18,8 @@
 
 import paddle.distributed as dist
 import paddle.distributed.fleet as fleet
-import paddle.distributed.fleet.meta_parallel.ColumnParallelLinear as CPLinear
-import paddle.distributed.fleet.meta_parallel.RowParallelLinear as RPLinear
+import paddle.distributed.fleet.meta_parallel as FMPLinear
+# import paddle.distributed.fleet.meta_parallel.RowParallelLinear as RPLinear
 model_parallel_size=2
 
 class SimplifiedSD3(nn.Layer):
@@ -36,15 +36,15 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio
         self.norm_last_context = nn.LayerNorm(self.dim, epsilon=1e-6, weight_attr=False, bias_attr=True)
 
         if model_parallel_size > 1:
-            self.qkv_mp = nn.LayerList([CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.eqkv_mp = nn.LayerList([CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.to_out_linear_mp = nn.LayerList([RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
-            self.to_add_out_linear_mp = nn.LayerList([RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.qkv_mp = nn.LayerList([FMPLinear.ColumnParallelLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.eqkv_mp = nn.LayerList([FMPLinear.ColumnParallelLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.to_out_linear_mp = nn.LayerList([FMPLinear.RowParallelLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.to_add_out_linear_mp = nn.LayerList([FMPLinear.RowParallelLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
 
-            self.ffn1_mp = nn.LayerList([CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.ffn2_mp = nn.LayerList([RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
-            self.ffn1_context_mp = nn.LayerList([CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.ffn2_context_mp = nn.LayerList([RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.ffn1_mp = nn.LayerList([FMPLinear.ColumnParallelLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.ffn2_mp = nn.LayerList([FMPLinear.RowParallelLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.ffn1_context_mp = nn.LayerList([FMPLinear.ColumnParallelLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.ffn2_context_mp = nn.LayerList([FMPLinear.RowParallelLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
         else:
             self.qkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])

From 9a21b8b4b186d5077c9149c7a0e23b9a4e8b0aa0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Mon, 28 Oct 2024 11:03:31 +0800
Subject: [PATCH 08/21] update MP

---
 .../ppdiffusers/models/simplified_sd3.py | 20 ++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index 9b207f7bc..d7a3de0f6 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -18,8 +18,10 @@
 
 import paddle.distributed as dist
 import paddle.distributed.fleet as fleet
-import paddle.distributed.fleet.meta_parallel as FMPLinear
+# import paddle.distributed.fleet.meta_parallel as FMPLinear
 # import paddle.distributed.fleet.meta_parallel.RowParallelLinear as RPLinear
+from paddle.distributed.fleet.meta_parallel import RowParallelLinear as RPLinear
+from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear as CPLinear
 model_parallel_size=2
 
 class SimplifiedSD3(nn.Layer):
@@ -36,15 +38,15 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio
         self.norm_last_context = nn.LayerNorm(self.dim, epsilon=1e-6, weight_attr=False, bias_attr=True)
 
         if model_parallel_size > 1:
-            self.qkv_mp = nn.LayerList([FMPLinear.ColumnParallelLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.eqkv_mp = nn.LayerList([FMPLinear.ColumnParallelLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.to_out_linear_mp = nn.LayerList([FMPLinear.RowParallelLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
-            self.to_add_out_linear_mp = nn.LayerList([FMPLinear.RowParallelLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.qkv_mp = nn.LayerList([CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.eqkv_mp = nn.LayerList([CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.to_out_linear_mp = nn.LayerList([RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.to_add_out_linear_mp = nn.LayerList([RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
 
-            self.ffn1_mp = nn.LayerList([FMPLinear.ColumnParallelLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.ffn2_mp = nn.LayerList([FMPLinear.RowParallelLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
-            self.ffn1_context_mp = nn.LayerList([FMPLinear.ColumnParallelLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.ffn2_context_mp = nn.LayerList([FMPLinear.RowParallelLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.ffn1_mp = nn.LayerList([CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.ffn2_mp = nn.LayerList([RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.ffn1_context_mp = nn.LayerList([CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.ffn2_context_mp = nn.LayerList([RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
         else:
             self.qkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])

From dccec21504d28038b83bd9115bc0a08d2adfedf3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Mon, 28 Oct 2024 11:59:04 +0800
Subject: [PATCH 09/21] update MP

---
 ppdiffusers/ppdiffusers/models/simplified_sd3.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index d7a3de0f6..f45a33117 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -18,8 +18,6 @@
 
 import paddle.distributed as dist
 import paddle.distributed.fleet as fleet
-# import paddle.distributed.fleet.meta_parallel as FMPLinear
-# import paddle.distributed.fleet.meta_parallel.RowParallelLinear as RPLinear
 from paddle.distributed.fleet.meta_parallel import RowParallelLinear as RPLinear
 from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear as CPLinear
 model_parallel_size=2
@@ -134,14 +132,9 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             q, k, v = paddlemix.triton_ops.split_concat(qkv, eqkv)
             bs = hidden_states.shape[0]
             hs = q.shape[2]
-            if model_parallel_size > 1:
-                q = q.reshape([bs, -1, hs//64, 64])
-                k = k.reshape([bs, -1, hs//64, 64])
-                v = v.reshape([bs, -1, hs//64, 64])
-            else:
-                q = q.reshape([bs, -1, 24, 64])
-                k = k.reshape([bs, -1, 24, 64])
-                v = v.reshape([bs, -1, 24, 64])
+            q = q.reshape([bs, -1, hs//64, 64])
+            k = k.reshape([bs, -1, hs//64, 64])
+            v = v.reshape([bs, -1, hs//64, 64])
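After patch 09 the head bookkeeping is uniform: each rank derives its head count from the width it actually sees after split_concat, so the same reshape works with and without tensor parallelism. With SD3-medium numbers assumed (24 heads of size 64, hidden width 1536), the arithmetic behind hs//64 is:

# local attention width under tensor parallelism (SD3-medium numbers assumed)
dim, num_heads, head_dim, mp_degree = 1536, 24, 64, 2

hs = dim // mp_degree           # width of q/k/v on each rank: 768
local_heads = hs // head_dim    # hs // 64 == 12 heads per rank (24 when mp_degree == 1)
assert local_heads * mp_degree == num_heads
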
From 74572b303390b71ff3afb1e2ff22145fa2097e7a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Mon, 28 Oct 2024 12:04:25 +0800
Subject: [PATCH 10/21] update mp

---
 .../ppdiffusers/models/simplified_sd3.py | 37 ++++++++++---------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index f45a33117..3fd4d8e6e 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -15,6 +15,7 @@
 import paddle
 import paddle.nn.functional as F
 from paddle import nn
+from paddle.nn import LayerList as LayerList
 
 import paddle.distributed as dist
 import paddle.distributed.fleet as fleet
@@ -29,32 +30,32 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio
         self.dim = dim
         self.silu = nn.Silu()
 
-        self.linear1 = nn.LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers)])
-        self.linear_context = nn.LayerList(
+        self.linear1 = LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers)])
+        self.linear_context = LayerList(
             [nn.Linear(self.dim, (6 if i < num_layers - 1 else 2) * self.dim) for i in range(num_layers)]
         )
         self.norm_last_context = nn.LayerNorm(self.dim, epsilon=1e-6, weight_attr=False, bias_attr=True)
 
         if model_parallel_size > 1:
-            self.qkv_mp = nn.LayerList([CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.eqkv_mp = nn.LayerList([CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.to_out_linear_mp = nn.LayerList([RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
-            self.to_add_out_linear_mp = nn.LayerList([RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.qkv_mp = LayerList([CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.eqkv_mp = LayerList([CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.to_out_linear_mp = LayerList([RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.to_add_out_linear_mp = LayerList([RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
 
-            self.ffn1_mp = nn.LayerList([CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.ffn2_mp = nn.LayerList([RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
-            self.ffn1_context_mp = nn.LayerList([CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
-            self.ffn2_context_mp = nn.LayerList([RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.ffn1_mp = LayerList([CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.ffn2_mp = LayerList([RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
+            self.ffn1_context_mp = LayerList([CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
+            self.ffn2_context_mp = LayerList([RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
         else:
-            self.qkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
-            self.eqkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
-            self.to_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
-            self.to_add_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
+            self.qkv = LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
+            self.eqkv = LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
+            self.to_out_linear = LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
+            self.to_add_out_linear = LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
 
-            self.ffn1 = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers)])
-            self.ffn2 = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers)])
-            self.ffn1_context = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers - 1)])
-            self.ffn2_context = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers - 1)])
+            self.ffn1 = LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers)])
+            self.ffn2 = LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers)])
+            self.ffn1_context = LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers - 1)])
+            self.ffn2_context = LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers - 1)])
 
 
     def forward(self, hidden_states, encoder_hidden_states, temb):

From f1c4358a616693feb4597e8cc1b7dc1d3169ee45 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Fri, 1 Nov 2024 15:06:51 +0800
Subject: [PATCH 11/21] update SD3 CFGP & T/MP

---
 ..._to_image_generation-stable_diffusion_3.py | 20 +++++--
 .../ppdiffusers/models/simplified_sd3.py      | 13 +++--
 .../ppdiffusers/models/transformer_sd3.py     | 50 +++++++++--------
 .../pipeline_stable_diffusion_3.py            | 56 ++++++++++---------
 4 files changed, 80 insertions(+), 59 deletions(-)

diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
index ae78d523d..d2b3f8642 100644
--- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
+++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py
@@ -48,7 +48,7 @@ def parse_args():
 
 if args.inference_optimize:
     os.environ["INFERENCE_OPTIMIZE"] = "True"
-    # os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True"
+    os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True"
 if args.inference_optimize_bp:
     os.environ["INFERENCE_OPTIMIZE_BP"] = "True"
 if args.dtype == "float32":
@@ -66,7 +66,7 @@ def parse_args():
     import paddle.distributed.fleet as fleet
     strategy = fleet.DistributedStrategy()
     model_parallel_size = 2
-    data_parallel_size = 1
+    data_parallel_size = 2
     strategy.hybrid_configs = {
         "dp_degree": data_parallel_size,
         "mp_degree": model_parallel_size,
@@ -75,11 +75,21 @@ def parse_args():
     fleet.init(is_collective=True, strategy=strategy)
     hcg = fleet.get_hybrid_communicate_group()
     mp_id = hcg.get_model_parallel_rank()
+    dp_id = hcg.get_data_parallel_rank()
     rank_id = dist.get_rank()
+    # mp_group = hcg.get_model_parallel_group()
+    # dp_group = hcg.get_data_parallel_group()
+    mp_degree = hcg.get_model_parallel_world_size()
+    dp_degree = hcg.get_data_parallel_world_size()
     if rank_id==0:
-        os.environ["TRITON_KERNEL_CACHE_DIR"]="./tmp/sd3_parallel/2_2"
+        os.environ["TRITON_KERNEL_CACHE_DIR"]="./tmp/sd3_parallel/0"
     elif rank_id==1:
-        os.environ["TRITON_KERNEL_CACHE_DIR"]="./tmp/sd3_parallel/2_3"
+        os.environ["TRITON_KERNEL_CACHE_DIR"]="./tmp/sd3_parallel/1"
+    elif rank_id==2:
+        os.environ["TRITON_KERNEL_CACHE_DIR"]="./tmp/sd3_parallel/2"
+    elif rank_id==3:
+        os.environ["TRITON_KERNEL_CACHE_DIR"]="./tmp/sd3_parallel/3"
+
 
 import datetime
 from ppdiffusers import StableDiffusion3Pipeline
@@ -92,7 +102,7 @@ def parse_args():
 
 pipe.transformer = paddle.incubate.jit.inference(
     pipe.transformer,
-    save_model_dir="./tmp/TP_sd3_parallel",
+    save_model_dir="./tmp/1024_TP_sd3_parallel",
     enable_new_ir=False,
     cache_static_model=False,

diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
index 3fd4d8e6e..78df075b2 100644
--- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py
@@ -21,7 +21,8 @@
 import paddle.distributed.fleet as fleet
 from paddle.distributed.fleet.meta_parallel import RowParallelLinear as RPLinear
 from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear as CPLinear
-model_parallel_size=2
+hcg = fleet.get_hybrid_communicate_group()
+mp_degree = hcg.get_model_parallel_world_size()
 
 class SimplifiedSD3(nn.Layer):
     def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
@@ -36,7 +37,7 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio
         )
         self.norm_last_context = nn.LayerNorm(self.dim, epsilon=1e-6, weight_attr=False, bias_attr=True)
 
-        if model_parallel_size > 1:
+        if mp_degree > 1:
             self.qkv_mp = LayerList([CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
@@ -121,7 +122,7 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
                 epsilon=1e-06,
             )
 
-            if model_parallel_size > 1:
+            if mp_degree > 1:
                 qkv = self.qkv_mp[i](norm_hidden_states)
@@ -145,7 +146,7 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
             #     norm_hidden_states1, num_or_sections=[1024, 154], axis=1
             # )
 
-            if model_parallel_size > 1:
+            if mp_degree > 1:
                 attn_output = self.to_out_linear_mp[i](attn_output)
@@ -157,7 +158,7 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
 
             # ffn1
-            if model_parallel_size > 1:
+            if mp_degree > 1:
                 ffn_output = self.ffn1_mp[i](norm_hidden_states)
@@ -180,7 +181,7 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
 
-                if model_parallel_size > 1:
+                if mp_degree > 1:
                     context_ffn_output = self.ffn1_context_mp[i](norm_encoder_hidden_states)

diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
index bae4b7dae..c1259adee 100644
--- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
@@ -429,27 +429,31 @@ def custom_modify_weight(cls, state_dict):
             import paddle.distributed as dist
             import paddle.distributed.fleet as fleet
             rank_id = dist.get_rank()
-            for mp_name in ["ffn2","to_out_linear"]:
-                tmpc = paddle.split(state_dict[f"simplified_sd3.{mp_name}.{i}.weight"],2,axis=0)
-                state_dict[f"simplified_sd3.{mp_name}_mp.{i}.weight"] = tmpc[rank_id]
-                state_dict[f"simplified_sd3.{mp_name}_mp.{i}.bias"] = state_dict[f"simplified_sd3.{mp_name}.{i}.bias"]
-            if i<23:
-                tmpc = paddle.split(state_dict[f"simplified_sd3.to_add_out_linear.{i}.weight"],2,axis=0)
-                state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.weight"] = tmpc[rank_id]
-                state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.bias"] = state_dict[f"simplified_sd3.to_add_out_linear.{i}.bias"]
-                tmpf = paddle.split(state_dict[f"simplified_sd3.ffn2_context.{i}.weight"],2,axis=0)
-                state_dict[f"simplified_sd3.ffn2_context_mp.{i}.weight"] = tmpf[rank_id]
-                state_dict[f"simplified_sd3.ffn2_context_mp.{i}.bias"] = state_dict[f"simplified_sd3.ffn2_context.{i}.bias"]
-                for placeholder in ["weight", "bias"]:
-                    tmpf1 = paddle.split(state_dict[f"simplified_sd3.ffn1_context.{i}.{placeholder}"],2,axis=-1)
-                    state_dict[f"simplified_sd3.ffn1_context_mp.{i}.{placeholder}"] = tmpf1[rank_id]
-            for placeholder in ["weight", "bias"]:
-                tmp = paddle.split(state_dict[f"simplified_sd3.ffn1.{i}.{placeholder}"],2,axis=-1)
-                state_dict[f"simplified_sd3.ffn1_mp.{i}.{placeholder}"] = tmp[rank_id]
-                for placeholder1 in ["", "e"]:
-                    tmpq = paddle.split(state_dict[f"simplified_sd3.{placeholder1}qkv.{i}.{placeholder}"],6,axis=-1)
-                    tmpqp = [
-                        paddle.concat([tmpq[0], tmpq[2],tmpq[4]],axis=-1),
-                        paddle.concat([tmpq[1], tmpq[3],tmpq[5]],axis=-1)
-                    ]
-                    state_dict[f"simplified_sd3.{placeholder1}qkv_mp.{i}.{placeholder}"] = tmpqp[rank_id]
+            hcg = fleet.get_hybrid_communicate_group()
+            mp_id = hcg.get_model_parallel_rank()
+            mp_degree = hcg.get_model_parallel_world_size()
+            if mp_degree > 1:
+                for mp_name in ["ffn2","to_out_linear"]:
+                    tmpc = paddle.split(state_dict[f"simplified_sd3.{mp_name}.{i}.weight"],2,axis=0)
+                    state_dict[f"simplified_sd3.{mp_name}_mp.{i}.weight"] = tmpc[mp_id]
+                    state_dict[f"simplified_sd3.{mp_name}_mp.{i}.bias"] = state_dict[f"simplified_sd3.{mp_name}.{i}.bias"]
+                if i<23:
+                    tmpc = paddle.split(state_dict[f"simplified_sd3.to_add_out_linear.{i}.weight"],2,axis=0)
+                    state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.weight"] = tmpc[mp_id]
+                    state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.bias"] = state_dict[f"simplified_sd3.to_add_out_linear.{i}.bias"]
+                    tmpf = paddle.split(state_dict[f"simplified_sd3.ffn2_context.{i}.weight"],2,axis=0)
+                    state_dict[f"simplified_sd3.ffn2_context_mp.{i}.weight"] = tmpf[mp_id]
+                    state_dict[f"simplified_sd3.ffn2_context_mp.{i}.bias"] = state_dict[f"simplified_sd3.ffn2_context.{i}.bias"]
+                    for placeholder in ["weight", "bias"]:
+                        tmpf1 = paddle.split(state_dict[f"simplified_sd3.ffn1_context.{i}.{placeholder}"],2,axis=-1)
+                        state_dict[f"simplified_sd3.ffn1_context_mp.{i}.{placeholder}"] = tmpf1[mp_id]
+                for placeholder in ["weight", "bias"]:
+                    tmp = paddle.split(state_dict[f"simplified_sd3.ffn1.{i}.{placeholder}"],2,axis=-1)
+                    state_dict[f"simplified_sd3.ffn1_mp.{i}.{placeholder}"] = tmp[mp_id]
+                    for placeholder1 in ["", "e"]:
+                        tmpq = paddle.split(state_dict[f"simplified_sd3.{placeholder1}qkv.{i}.{placeholder}"],6,axis=-1)
+                        tmpqp = [
+                            paddle.concat([tmpq[0], tmpq[2],tmpq[4]],axis=-1),
+                            paddle.concat([tmpq[1], tmpq[3],tmpq[5]],axis=-1)
+                        ]
+                        state_dict[f"simplified_sd3.{placeholder1}qkv_mp.{i}.{placeholder}"] = tmpqp[mp_id]

diff --git a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index be89fc33a..fc36c23e0 100644
--- a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -18,7 +18,8 @@
 
 import paddle
 import paddle.distributed as dist
-
+import paddle.distributed.fleet as fleet
+
 from ppdiffusers.transformers import (  # T5TokenizerFast,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
@@ -196,7 +197,6 @@ def __init__(
             if hasattr(self, "transformer") and self.transformer is not None
             else 128
         )
-        self.inference_optimize_bp = os.getenv("INFERENCE_OPTIMIZE_BP") == "True"
 
     def _get_t5_prompt_embeds(
         self,
@@ -800,22 +800,28 @@ def __call__(
                 latent_model_input = paddle.concat([latents] * 2) if self.do_classifier_free_guidance else latents
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
-                # if self.inference_optimize_bp and self.do_classifier_free_guidance:
-                #     latent_input ,latent_model_input_ = paddle.split(latent_model_input,2,axis=0)
-                #     timestep_input ,timestep_ = paddle.split(timestep,2,axis=0)
-                #     prompt_embeds_input ,prompt_embeds_ = paddle.split(prompt_embeds,2,axis=0)
-                #     pooled_prompt_embeds_input ,pooled_prompt_embeds_ = paddle.split(pooled_prompt_embeds,2,axis=0)
+
+                hcg = fleet.get_hybrid_communicate_group()
+                dp_degree = hcg.get_data_parallel_world_size()
+                if dp_degree > 1 and self.do_classifier_free_guidance:
+                    dp_id = hcg.get_data_parallel_rank()
+                    dp_group = hcg.get_data_parallel_group()
+
+                    tmp_latent_input = paddle.split(latent_model_input,2,axis=0)
+                    tmp_timestep = paddle.split(timestep,2,axis=0)
+                    tmp_prompt_embeds = paddle.split(prompt_embeds,2,axis=0)
+                    tmp_pooled_prompt_embeds = paddle.split(pooled_prompt_embeds,2,axis=0)
+
+                    latent_input =tmp_latent_input[dp_id]
+                    timestep_input = tmp_timestep[dp_id]
+                    prompt_embeds_input = tmp_prompt_embeds[dp_id]
+                    pooled_prompt_embeds_input = tmp_pooled_prompt_embeds[dp_id]
 
-                # dist.scatter(latent_input,[latent_input,latent_model_input_])
-                # dist.scatter(timestep_input,[timestep_input,timestep_])
-                # dist.scatter(prompt_embeds_input,[prompt_embeds_input,prompt_embeds_])
-                # dist.scatter(pooled_prompt_embeds_input,[pooled_prompt_embeds_input,pooled_prompt_embeds_])
-
-                # else:
-                latent_input = latent_model_input
-                timestep_input = timestep
-                prompt_embeds_input = prompt_embeds
-                pooled_prompt_embeds_input = pooled_prompt_embeds
+                else:
+                    latent_input = latent_model_input
+                    timestep_input = timestep
+                    prompt_embeds_input = prompt_embeds
+                    pooled_prompt_embeds_input = pooled_prompt_embeds
 
                 model_output = self.transformer(
                     hidden_states=latent_input,
@@ -831,14 +837,14 @@ def __call__(
                     output = model_output
                 else:
                     output = model_output[0]
-
-                # if self.inference_optimize_bp:
-                #     tmp_shape = output.shape
-                #     tmp_shape[0] *=2
-                #     noise_pred = paddle.zeros(tmp_shape,dtype=output.dtype)
-                #     dist.all_gather(noise_pred,output)
-                # else:
-                noise_pred = output
+
+                if dp_degree > 1 and self.do_classifier_free_guidance:
+                    tmp_shape = output.shape
+                    tmp_shape[0] *=2
+                    noise_pred = paddle.zeros(tmp_shape,dtype=output.dtype)
+                    dist.all_gather(noise_pred, output, group=dp_group)
+                else:
+                    noise_pred = output
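Patch 11's pipeline change re-enables the CFG batch split, but keyed off the hybrid topology rather than the retired INFERENCE_OPTIMIZE_BP flag, and scoped to the data-parallel group so it composes with tensor parallelism (both mp ranks inside a dp group see the same half). A condensed sketch of that round trip under fleet, assuming dp and mp degree 2 and a multiply standing in for the sharded transformer; the split factor 2 matches the hard-coded CFG halves in the diff:

import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)
hcg = fleet.get_hybrid_communicate_group()

dp_id = hcg.get_data_parallel_rank()
dp_group = hcg.get_data_parallel_group()

cfg_batch = paddle.randn([2, 16, 128, 128])          # [uncond; cond]
local = paddle.split(cfg_batch, 2, axis=0)[dp_id]    # one half per dp rank

output = local * 0.5                                 # stand-in for the transformer

# gather in dp-rank order, which restores the [uncond; cond] layout on every rank
shape = output.shape
shape[0] *= 2
noise_pred = paddle.zeros(shape, dtype=output.dtype)
dist.all_gather(noise_pred, output, group=dp_group)
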
++++----- 4 files changed, 102 insertions(+), 99 deletions(-) diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index d2b3f8642..831841d8b 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import argparse +import os + import paddle + + def parse_args(): parser = argparse.ArgumentParser( description=" Use PaddleMIX to accelerate the Stable Diffusion3 image generation model." @@ -40,6 +43,8 @@ def parse_args(): parser.add_argument("--width", type=int, default=512, help="Width of the generated image.") parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of inference steps.") parser.add_argument("--dtype", type=str, default="float32", help="Inference data types.") + parser.add_argument("--mp_size", type=int, default=1, help="Inference data types.") + parser.add_argument("--dp_size", type=int, default=1, help="Inference data types.") return parser.parse_args() @@ -49,57 +54,39 @@ def parse_args(): if args.inference_optimize: os.environ["INFERENCE_OPTIMIZE"] = "True" os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True" -if args.inference_optimize_bp: - os.environ["INFERENCE_OPTIMIZE_BP"] = "True" if args.dtype == "float32": inference_dtype = paddle.float32 elif args.dtype == "float16": inference_dtype = paddle.float16 -if args.inference_optimize_bp: - from paddle.distributed import fleet - from paddle.distributed.fleet.utils import recompute - import numpy as np - import random - import paddle.distributed as dist - import paddle.distributed.fleet as fleet - strategy = fleet.DistributedStrategy() - model_parallel_size = 2 - data_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": data_parallel_size, - "mp_degree": model_parallel_size, - "pp_degree": 1 - } - fleet.init(is_collective=True, strategy=strategy) - hcg = fleet.get_hybrid_communicate_group() - mp_id = hcg.get_model_parallel_rank() - dp_id = hcg.get_data_parallel_rank() - rank_id = dist.get_rank() - # mp_group = hcg.get_model_parallel_group() - # dp_group = hcg.get_data_parallel_group() - mp_degree = hcg.get_model_parallel_world_size() - dp_degree = hcg.get_data_parallel_world_size() - if rank_id==0: - os.environ["TRITON_KERNEL_CACHE_DIR"]="./tmp/sd3_parallel/0" - elif rank_id==1: - os.environ["TRITON_KERNEL_CACHE_DIR"]="./tmp/sd3_parallel/1" - elif rank_id==2: - os.environ["TRITON_KERNEL_CACHE_DIR"]="./tmp/sd3_parallel/2" - elif rank_id==3: - os.environ["TRITON_KERNEL_CACHE_DIR"]="./tmp/sd3_parallel/3" +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +strategy = fleet.DistributedStrategy() +model_parallel_size = args.mp_size +data_parallel_size = args.dp_size +strategy.hybrid_configs = {"dp_degree": data_parallel_size, "mp_degree": model_parallel_size, "pp_degree": 1} +fleet.init(is_collective=True, strategy=strategy) +hcg = fleet.get_hybrid_communicate_group() +mp_id = hcg.get_model_parallel_rank() +dp_id = hcg.get_data_parallel_rank() +rank_id = dist.get_rank() +# mp_group = hcg.get_model_parallel_group() +# dp_group = hcg.get_data_parallel_group() +mp_degree = hcg.get_model_parallel_world_size() +dp_degree = 
hcg.get_data_parallel_world_size() + +os.environ["TRITON_KERNEL_CACHE_DIR"] = f"./tmp/sd3_parallel/{rank_id}" import datetime -from ppdiffusers import StableDiffusion3Pipeline +from ppdiffusers import StableDiffusion3Pipeline pipe = StableDiffusion3Pipeline.from_pretrained( "stabilityai/stable-diffusion-3-medium-diffusers", paddle_dtype=inference_dtype, ) - pipe.transformer = paddle.incubate.jit.inference( pipe.transformer, save_model_dir="./tmp/1024_TP_sd3_parallel", @@ -151,8 +138,7 @@ def parse_args(): cuda_mem_after_used = paddle.device.cuda.max_memory_allocated() / (1024**3) print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB") -if args.inference_optimize_bp: - if rank_id == 0: + if dp_degree > 1 or mp_degree > 1: image.save("text_to_image_generation-stable_diffusion_3-result.png") else: image.save("text_to_image_generation-stable_diffusion_3-result.png") diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py index 78df075b2..f13d6be1a 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py +++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py @@ -13,17 +13,17 @@ # limitations under the License. import paddle +import paddle.distributed.fleet as fleet import paddle.nn.functional as F from paddle import nn +from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear as CPLinear +from paddle.distributed.fleet.meta_parallel import RowParallelLinear as RPLinear from paddle.nn import LayerList as LayerList -import paddle.distributed as dist -import paddle.distributed.fleet as fleet -from paddle.distributed.fleet.meta_parallel import RowParallelLinear as RPLinear -from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear as CPLinear hcg = fleet.get_hybrid_communicate_group() mp_degree = hcg.get_model_parallel_world_size() + class SimplifiedSD3(nn.Layer): def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int): super().__init__() @@ -38,27 +38,47 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio self.norm_last_context = nn.LayerNorm(self.dim, epsilon=1e-6, weight_attr=False, bias_attr=True) if mp_degree > 1: - self.qkv_mp = LayerList([CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)]) - self.eqkv_mp = LayerList([CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)]) - self.to_out_linear_mp = LayerList([RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)]) - self.to_add_out_linear_mp = LayerList([RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)]) - - self.ffn1_mp = LayerList([CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)]) - self.ffn2_mp = LayerList([RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)]) - self.ffn1_context_mp = LayerList([CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)]) - self.ffn2_context_mp = LayerList([RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)]) + self.qkv_mp = LayerList( + [CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)] + ) + self.eqkv_mp = LayerList( + [CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)] + ) + self.to_out_linear_mp 
= LayerList( + [RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)] + ) + # When using Model Parallel, for the symmetry of GEMM, we change num_layers-1 here to num_layers, which has no effect on the results. + self.to_add_out_linear_mp = LayerList( + [RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)] + ) + + self.ffn1_mp = LayerList( + [CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)] + ) + self.ffn2_mp = LayerList( + [RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)] + ) + self.ffn1_context_mp = LayerList( + [CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers - 1)] + ) + self.ffn2_context_mp = LayerList( + [ + RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) + for i in range(num_layers - 1) + ] + ) else: self.qkv = LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)]) self.eqkv = LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)]) self.to_out_linear = LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) + # When using Model Parallel, for the symmetry of GEMM, we change num_layers-1 here to num_layers, which has no effect on the results. self.to_add_out_linear = LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)]) - + self.ffn1 = LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers)]) self.ffn2 = LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers)]) self.ffn1_context = LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers - 1)]) self.ffn2_context = LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers - 1)]) - def forward(self, hidden_states, encoder_hidden_states, temb): print("--------------------this is simplified_sd3------------------------") temb_silu = self.silu(temb) @@ -125,21 +145,20 @@ def forward(self, hidden_states, encoder_hidden_states, temb): if mp_degree > 1: qkv = self.qkv_mp[i](norm_hidden_states) eqkv = self.eqkv_mp[i](norm_encoder_hidden_states) - + else: qkv = self.qkv[i](norm_hidden_states) eqkv = self.eqkv[i](norm_encoder_hidden_states) - q, k, v = paddlemix.triton_ops.split_concat(qkv, eqkv) bs = hidden_states.shape[0] - hs = q.shape[2] - q = q.reshape([bs, -1, hs//64, 64]) - k = k.reshape([bs, -1, hs//64, 64]) - v = v.reshape([bs, -1, hs//64, 64]) + head_nums = q.shape[2] // 64 + q = q.reshape([bs, -1, head_nums, 64]) + k = k.reshape([bs, -1, head_nums, 64]) + v = v.reshape([bs, -1, head_nums, 64]) norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False) - norm_hidden_states1 = norm_hidden_states1.reshape([bs, -1, hs]) + norm_hidden_states1 = norm_hidden_states1.reshape([bs, -1, head_nums * 64]) attn_output, context_attn_output = paddle.split(norm_hidden_states1, num_or_sections=[seq1, seq2], axis=1) # attn_output, context_attn_output = paddlemix.triton_ops.triton_split( @@ -194,4 +213,4 @@ def forward(self, hidden_states, encoder_hidden_states, temb): last_context_hidden_states = encoder_hidden_states last_context_gate_mlp = c_gate_mlp - return hidden_states + return hidden_states diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py index c1259adee..d624bce09 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py +++ 
b/ppdiffusers/ppdiffusers/models/transformer_sd3.py @@ -425,35 +425,41 @@ def custom_modify_weight(cls, state_dict): ], axis=-1, ) - from paddle.distributed.fleet.utils import recompute - import paddle.distributed as dist import paddle.distributed.fleet as fleet - rank_id = dist.get_rank() + hcg = fleet.get_hybrid_communicate_group() mp_id = hcg.get_model_parallel_rank() mp_degree = hcg.get_model_parallel_world_size() if mp_degree > 1: - for mp_name in ["ffn2","to_out_linear"]: - tmpc = paddle.split(state_dict[f"simplified_sd3.{mp_name}.{i}.weight"],2,axis=0) - state_dict[f"simplified_sd3.{mp_name}_mp.{i}.weight"] = tmpc[mp_id] - state_dict[f"simplified_sd3.{mp_name}_mp.{i}.bias"] = state_dict[f"simplified_sd3.{mp_name}.{i}.bias"] - if i<23: - tmpc = paddle.split(state_dict[f"simplified_sd3.to_add_out_linear.{i}.weight"],2,axis=0) + if i < 23: + tmpc = paddle.split(state_dict[f"simplified_sd3.to_add_out_linear.{i}.weight"], 2, axis=0) state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.weight"] = tmpc[mp_id] - state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.bias"] = state_dict[f"simplified_sd3.to_add_out_linear.{i}.bias"] - tmpf = paddle.split(state_dict[f"simplified_sd3.ffn2_context.{i}.weight"],2,axis=0) + state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.bias"] = state_dict[ + f"simplified_sd3.to_add_out_linear.{i}.bias" + ] + tmpf = paddle.split(state_dict[f"simplified_sd3.ffn2_context.{i}.weight"], 2, axis=0) state_dict[f"simplified_sd3.ffn2_context_mp.{i}.weight"] = tmpf[mp_id] - state_dict[f"simplified_sd3.ffn2_context_mp.{i}.bias"] = state_dict[f"simplified_sd3.ffn2_context.{i}.bias"] + state_dict[f"simplified_sd3.ffn2_context_mp.{i}.bias"] = state_dict[ + f"simplified_sd3.ffn2_context.{i}.bias" + ] for placeholder in ["weight", "bias"]: - tmpf1 = paddle.split(state_dict[f"simplified_sd3.ffn1_context.{i}.{placeholder}"],2,axis=-1) + tmpf1 = paddle.split(state_dict[f"simplified_sd3.ffn1_context.{i}.{placeholder}"], 2, axis=-1) state_dict[f"simplified_sd3.ffn1_context_mp.{i}.{placeholder}"] = tmpf1[mp_id] for placeholder in ["weight", "bias"]: - tmp = paddle.split(state_dict[f"simplified_sd3.ffn1.{i}.{placeholder}"],2,axis=-1) + tmp = paddle.split(state_dict[f"simplified_sd3.ffn1.{i}.{placeholder}"], 2, axis=-1) state_dict[f"simplified_sd3.ffn1_mp.{i}.{placeholder}"] = tmp[mp_id] for placeholder1 in ["", "e"]: - tmpq = paddle.split(state_dict[f"simplified_sd3.{placeholder1}qkv.{i}.{placeholder}"],6,axis=-1) + tmpq = paddle.split( + state_dict[f"simplified_sd3.{placeholder1}qkv.{i}.{placeholder}"], 6, axis=-1 + ) tmpqp = [ - paddle.concat([tmpq[0], tmpq[2],tmpq[4]],axis=-1), - paddle.concat([tmpq[1], tmpq[3],tmpq[5]],axis=-1) + paddle.concat([tmpq[0], tmpq[2], tmpq[4]], axis=-1), + paddle.concat([tmpq[1], tmpq[3], tmpq[5]], axis=-1), ] state_dict[f"simplified_sd3.{placeholder1}qkv_mp.{i}.{placeholder}"] = tmpqp[mp_id] + for mp_name in ["ffn2", "to_out_linear"]: + tmpc = paddle.split(state_dict[f"simplified_sd3.{mp_name}.{i}.weight"], 2, axis=0) + state_dict[f"simplified_sd3.{mp_name}_mp.{i}.weight"] = tmpc[mp_id] + state_dict[f"simplified_sd3.{mp_name}_mp.{i}.bias"] = state_dict[ + f"simplified_sd3.{mp_name}.{i}.bias" + ] diff --git a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index fc36c23e0..74855b3e8 100644 --- a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ 
b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -12,14 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os import inspect from typing import Any, Callable, Dict, List, Optional, Union import paddle import paddle.distributed as dist import paddle.distributed.fleet as fleet - + from ppdiffusers.transformers import ( # T5TokenizerFast, CLIPTextModelWithProjection, CLIPTokenizer, @@ -121,7 +120,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin): +class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin): r""" Args: @@ -800,29 +799,24 @@ def __call__( latent_model_input = paddle.concat([latents] * 2) if self.do_classifier_free_guidance else latents # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - + hcg = fleet.get_hybrid_communicate_group() dp_degree = hcg.get_data_parallel_world_size() if dp_degree > 1 and self.do_classifier_free_guidance: dp_id = hcg.get_data_parallel_rank() dp_group = hcg.get_data_parallel_group() - - tmp_latent_input = paddle.split(latent_model_input,2,axis=0) - tmp_timestep = paddle.split(timestep,2,axis=0) - tmp_prompt_embeds = paddle.split(prompt_embeds,2,axis=0) - tmp_pooled_prompt_embeds = paddle.split(pooled_prompt_embeds,2,axis=0) - - latent_input =tmp_latent_input[dp_id] - timestep_input = tmp_timestep[dp_id] - prompt_embeds_input = tmp_prompt_embeds[dp_id] - pooled_prompt_embeds_input = tmp_pooled_prompt_embeds[dp_id] - + + latent_input = paddle.split(latent_model_input, 2, axis=0)[dp_id] + timestep_input = paddle.split(timestep, 2, axis=0)[dp_id] + prompt_embeds_input = paddle.split(prompt_embeds, 2, axis=0)[dp_id] + pooled_prompt_embeds_input = paddle.split(pooled_prompt_embeds, 2, axis=0)[dp_id] + else: latent_input = latent_model_input timestep_input = timestep prompt_embeds_input = prompt_embeds pooled_prompt_embeds_input = pooled_prompt_embeds - + model_output = self.transformer( hidden_states=latent_input, timestep=timestep_input, @@ -840,14 +834,12 @@ def __call__( if dp_degree > 1 and self.do_classifier_free_guidance: tmp_shape = output.shape - tmp_shape[0] *=2 - noise_pred = paddle.zeros(tmp_shape,dtype=output.dtype) + tmp_shape[0] *= 2 + noise_pred = paddle.zeros(tmp_shape, dtype=output.dtype) dist.all_gather(noise_pred, output, group=dp_group) else: noise_pred = output - - # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) From d76d5bda60c8cecb85fdd639f61c0eb39034fb18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com> Date: Mon, 4 Nov 2024 20:35:41 +0800 Subject: [PATCH 13/21] update SD3 --- ppdiffusers/ppdiffusers/models/simplified_sd3.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py index f13d6be1a..60d7ffee1 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py +++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py @@ -29,6 +29,7 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio super().__init__() self.num_layers = num_layers self.dim = dim + self.head_dim = 64 self.silu = 
nn.Silu() self.linear1 = LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers)]) @@ -152,13 +153,13 @@ def forward(self, hidden_states, encoder_hidden_states, temb): q, k, v = paddlemix.triton_ops.split_concat(qkv, eqkv) bs = hidden_states.shape[0] - head_nums = q.shape[2] // 64 - q = q.reshape([bs, -1, head_nums, 64]) - k = k.reshape([bs, -1, head_nums, 64]) - v = v.reshape([bs, -1, head_nums, 64]) + head_nums = q.shape[2] // self.head_dim + q = q.reshape([bs, -1, head_nums, self.head_dim]) + k = k.reshape([bs, -1, head_nums, self.head_dim]) + v = v.reshape([bs, -1, head_nums, self.head_dim]) norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False) - norm_hidden_states1 = norm_hidden_states1.reshape([bs, -1, head_nums * 64]) + norm_hidden_states1 = norm_hidden_states1.reshape([bs, -1, head_nums * self.head_dim]) attn_output, context_attn_output = paddle.split(norm_hidden_states1, num_or_sections=[seq1, seq2], axis=1) # attn_output, context_attn_output = paddlemix.triton_ops.triton_split( From 1069163f1f6b4eeca47e252309d5c44e9bffee8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com> Date: Tue, 5 Nov 2024 17:29:03 +0800 Subject: [PATCH 14/21] update SD3 parallel --- .../text_to_image_generation-stable_diffusion_3.py | 10 ++++++---- ppdiffusers/ppdiffusers/models/simplified_sd3.py | 3 +-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index 831841d8b..ff2d8c395 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -43,8 +43,12 @@ def parse_args(): parser.add_argument("--width", type=int, default=512, help="Width of the generated image.") parser.add_argument("--num-inference-steps", type=int, default=50, help="Number of inference steps.") parser.add_argument("--dtype", type=str, default="float32", help="Inference data types.") - parser.add_argument("--mp_size", type=int, default=1, help="Inference data types.") - parser.add_argument("--dp_size", type=int, default=1, help="Inference data types.") + parser.add_argument( + "--mp_size", type=int, default=1, help="This size refers to the degree of parallelism using model parallel." + ) + parser.add_argument( + "--dp_size", type=int, default=1, help="This size refers to the degree of parallelism using data parallel." 
+ ) return parser.parse_args() @@ -72,8 +76,6 @@ def parse_args(): mp_id = hcg.get_model_parallel_rank() dp_id = hcg.get_data_parallel_rank() rank_id = dist.get_rank() -# mp_group = hcg.get_model_parallel_group() -# dp_group = hcg.get_data_parallel_group() mp_degree = hcg.get_model_parallel_world_size() dp_degree = hcg.get_data_parallel_world_size() diff --git a/ppdiffusers/ppdiffusers/models/simplified_sd3.py b/ppdiffusers/ppdiffusers/models/simplified_sd3.py index 60d7ffee1..b1cb4bada 100644 --- a/ppdiffusers/ppdiffusers/models/simplified_sd3.py +++ b/ppdiffusers/ppdiffusers/models/simplified_sd3.py @@ -20,8 +20,7 @@ from paddle.distributed.fleet.meta_parallel import RowParallelLinear as RPLinear from paddle.nn import LayerList as LayerList -hcg = fleet.get_hybrid_communicate_group() -mp_degree = hcg.get_model_parallel_world_size() +mp_degree = fleet.get_hybrid_communicate_group().get_model_parallel_world_size() class SimplifiedSD3(nn.Layer): From 643332e37022522706e9fa6e1f3c023ee701177c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com> Date: Tue, 5 Nov 2024 19:13:28 +0800 Subject: [PATCH 15/21] merge SD3 --- .../deploy/sd3/text_to_image_generation-stable_diffusion_3.py | 2 +- .../inference/text_to_image_generation-stable_diffusion_3.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ppdiffusers/deploy/sd3/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/deploy/sd3/text_to_image_generation-stable_diffusion_3.py index ff2d8c395..f848a893e 100644 --- a/ppdiffusers/deploy/sd3/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/deploy/sd3/text_to_image_generation-stable_diffusion_3.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py index 90c6450e3..861a8ed1c 100644 --- a/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py +++ b/ppdiffusers/examples/inference/text_to_image_generation-stable_diffusion_3.py @@ -13,7 +13,9 @@ # limitations under the License. 
 import paddle
+
 from ppdiffusers import StableDiffusion3Pipeline
+
 pipe = StableDiffusion3Pipeline.from_pretrained(
     "stabilityai/stable-diffusion-3-medium-diffusers", paddle_dtype=paddle.float16
 )

From ddf0f5622f58f7b9f68ffaecfdf4ccbb3351f2de Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Tue, 5 Nov 2024 19:37:20 +0800
Subject: [PATCH 16/21] update sd3 README

---
 ppdiffusers/deploy/sd3/README.md | 40 +++++++++++++++++++-------------
 1 file changed, 24 insertions(+), 16 deletions(-)

diff --git a/ppdiffusers/deploy/sd3/README.md b/ppdiffusers/deploy/sd3/README.md
index b3d7f19b2..75d8f0b2b 100644
--- a/ppdiffusers/deploy/sd3/README.md
+++ b/ppdiffusers/deploy/sd3/README.md
@@ -31,28 +31,36 @@
-## Multi-GPU inference for the Paddle Stable Diffusion 3 model:
-### How batch parallel works
-- In SD3, when the input is a single prompt, CFG generates the unconditional guide and the text guide at the same time, so the MM-DiT-blocks see an input with batch_size=2;
-in the multi-GPU scheme we therefore split this batch-2 input across two cards, halving the computation and the floating-point workload carried by each card.
-Once the computation finishes, the results of the two cards are gathered back together; the result is identical to single-card computation.
-### How to enable multi-GPU inference
-- Paddle Inference provides multi-GPU inference for the SD3 model; users can enable it by setting `--inference_optimize_bp 1`,
-and use `python -m paddle.distributed.launch --gpus 0,1` to choose which cards run the inference.
+## Multi-GPU inference for the Paddle Stable Diffusion 3 model:
+### How Data Parallel works
+- In SD3, when the input is a single prompt, CFG generates the unconditional guide and the text guide at the same time, so the MM-DiT-blocks see an input with batch_size=2;
+in the multi-GPU scheme we therefore split this batch-2 input across two cards, halving the computation and the floating-point workload carried by each card.
+Once the computation finishes, the results of the two cards are gathered back together; the result is identical to single-card computation.
+
+### How Model Parallel works
+- In SD3, the Linear and Attention layers contain a large number of GEMMs (General Matrix Multiply); when generating high-resolution images, both the GEMM compute and the size of the pretrained weights grow linearly.
+In the multi-GPU scheme we therefore split these GEMMs across two cards, halving each card's compute and weight size, which lowers both the floating-point workload and the GPU memory usage per card.
+
+### How to enable multi-GPU inference
+- Paddle Inference provides multi-GPU inference for the SD3 model; set `mp_size 2` to enable Model Parallel and `dp_size 2` to enable Data Parallel.
+Use `python -m paddle.distributed.launch --gpus "0,1,2,3"` to choose which cards run the inference, where `--gpus "0,1,2,3"` lists the GPU ids to enable.
+To use only two cards, just list two ids, e.g. `python -m paddle.distributed.launch --gpus "0,1"`, and also specify the parallel method and degree, e.g. `mp_size 2` or `dp_size 2`.
+Note that `mp_size` must be set no larger than the input batch_size, and the product of `mp_size` and `dp_size` must not exceed the total number of cards on the machine.
 High-performance multi-GPU inference command:
 ```shell
 # Run multi-GPU inference
-python -m paddle.distributed.launch --gpus 0,1 text_to_image_generation-stable_diffusion_3.py \
+python -m paddle.distributed.launch --gpus "0,1,2,3" text_to_image_generation-stable_diffusion_3.py \
 --dtype float16 \
---height 512 --width 512 \
---num-inference-steps 50 \
+--height 1024 \
+--width 1024 \
+--num-inference-steps 20 \
 --inference_optimize 1 \
---inference_optimize_bp 1 \
+--mp_size 2 \
+--dp_size 2 \
 --benchmark 1
 ```
 ## Performance measured on an NVIDIA A800-SXM4-80GB:
-
-| Paddle batch parallel | Paddle Single Card | PyTorch | TensorRT | Paddle dygraph |
-| --------------------- | ------------------ | --------- | -------- | -------------- |
-| 0.86 s | 1.2 s | 1.78 s | 1.16 s | 4.202 s |
\ No newline at end of file
+| Paddle mp_size=2 & dp_size=2 | Paddle mp_size=2 | Paddle dp_size=2 | Paddle dygraph |
+| ---------------------------- | ---------------- | ---------------- | -------------- |
+| 0.99 s | 1.581 s | 1.319 s | 4.202 s |
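The Model Parallel scheme described in this README maps directly onto fleet's parallel linear layers that the patches use in `simplified_sd3.py`. The following is a minimal sketch, not part of the patch series, of how one FFN GEMM pair is sharded; the 2-GPU hybrid setup, hidden size, and input shape are assumptions for illustration:

```python
# Sketch of the MP FFN split: the up-projection is a ColumnParallelLinear, so
# each rank holds half of the 4*dim output columns (gather_output=False means
# no collective between the two GEMMs); GELU is applied on the local shard;
# the RowParallelLinear down-projection all-reduces its partial sums, so every
# rank ends up with the full [.., dim] activation while storing half the weights.
import paddle
import paddle.distributed.fleet as fleet
import paddle.nn.functional as F
from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear, RowParallelLinear

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)  # assumes a 2-GPU launch

dim = 1536  # illustrative; SD3-medium's hidden size
ffn1 = ColumnParallelLinear(dim, 4 * dim, gather_output=False, has_bias=True)
ffn2 = RowParallelLinear(4 * dim, dim, input_is_parallel=True, has_bias=True)

x = paddle.randn([2, 1024, dim])
out = ffn2(F.gelu(ffn1(x), approximate=True))  # full [2, 1024, dim] on every rank
```

Run it under `python -m paddle.distributed.launch --gpus "0,1"`; each rank stores only half of both weight matrices, which is where the halved per-card memory described above comes from.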
From 06c81c1318179c103acc7b43af02aeef544de56d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Tue, 5 Nov 2024 19:38:55 +0800
Subject: [PATCH 17/21] update sd3 README

---
 ppdiffusers/deploy/sd3/README.md | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/ppdiffusers/deploy/sd3/README.md b/ppdiffusers/deploy/sd3/README.md
index 75d8f0b2b..a1dd923fa 100644
--- a/ppdiffusers/deploy/sd3/README.md
+++ b/ppdiffusers/deploy/sd3/README.md
@@ -45,8 +45,9 @@ python text_to_image_generation-stable_diffusion_3.py --dtype float16 --height
 - Paddle Inference provides multi-GPU inference for the SD3 model; set `mp_size 2` to enable Model Parallel and `dp_size 2` to enable Data Parallel.
 Use `python -m paddle.distributed.launch --gpus "0,1,2,3"` to choose which cards run the inference, where `--gpus "0,1,2,3"` lists the GPU ids to enable.
 To use only two cards, just list two ids, e.g. `python -m paddle.distributed.launch --gpus "0,1"`, and also specify the parallel method and degree, e.g. `mp_size 2` or `dp_size 2`.
-Note that `mp_size` must be set no larger than the input batch_size, and the product of `mp_size` and `dp_size` must not exceed the total number of cards on the machine.
-High-performance multi-GPU inference command:
+
+- Note that `mp_size` must be set no larger than the input batch_size, and the product of `mp_size` and `dp_size` must not exceed the total number of cards on the machine.
+- High-performance multi-GPU inference command:
 ```shell
 # Run multi-GPU inference
 python -m paddle.distributed.launch --gpus "0,1,2,3" text_to_image_generation-stable_diffusion_3.py \

From a95bb4e43e5d4e98f8cea990277650c8ecd7e7b8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Tue, 5 Nov 2024 19:44:00 +0800
Subject: [PATCH 18/21] update sd3 README

---
 ppdiffusers/deploy/sd3/README.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ppdiffusers/deploy/sd3/README.md b/ppdiffusers/deploy/sd3/README.md
index a1dd923fa..235871d1a 100644
--- a/ppdiffusers/deploy/sd3/README.md
+++ b/ppdiffusers/deploy/sd3/README.md
@@ -62,6 +62,6 @@ python -m paddle.distributed.launch --gpus "0,1,2,3" text_to_image_generation-st
 ```
 ## Performance measured on an NVIDIA A800-SXM4-80GB:
-| Paddle mp_size=2 & dp_size=2 | Paddle mp_size=2 | Paddle dp_size=2 | Paddle dygraph |
-| ---------------------------- | ---------------- | ---------------- | -------------- |
-| 0.99 s | 1.581 s | 1.319 s | 4.202 s |
+| Paddle mp_size=2 & dp_size=2 | Paddle mp_size=2 | Paddle dp_size=2 | Paddle Single Card | Paddle dygraph |
+| ---------------------------- | ---------------- | ---------------- | ------------------ | -------------- |
+| 0.99 s | 1.581 s | 1.319 s | 2.376 s | 3.2 s |
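A misconfigured launch tends to fail inside the collectives with an opaque error, so it can help to validate the constraints above before calling `fleet.init`. A sketch under stated assumptions; the helper name is hypothetical, and it assumes fleet's hybrid setup wants `mp_size * dp_size` to equal the number of launched processes:

```python
# Hypothetical pre-flight guard mirroring the README's constraints: mp_size is
# bounded by the effective batch (2 under CFG), and mp_size * dp_size must
# match the number of processes created by paddle.distributed.launch.
import paddle.distributed as dist

def check_parallel_config(mp_size: int, dp_size: int, cfg_batch: int = 2) -> None:
    nranks = dist.get_world_size()
    if mp_size > cfg_batch:
        raise ValueError(f"mp_size={mp_size} exceeds the input batch size {cfg_batch}")
    if mp_size * dp_size != nranks:
        raise ValueError(f"mp_size * dp_size = {mp_size * dp_size}, but {nranks} ranks were launched")

check_parallel_config(mp_size=2, dp_size=2)  # e.g. launched with --gpus "0,1,2,3"
```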
From e92154aa67db93dd80ccfd8c795f297a6702034c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Thu, 7 Nov 2024 10:43:20 +0800
Subject: [PATCH 19/21] update SD3 & doc

---
 ppdiffusers/deploy/sd3/README.md              | 10 ++++++++
 .../ppdiffusers/models/transformer_sd3.py     | 20 +++++++++----------
 2 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/ppdiffusers/deploy/sd3/README.md b/ppdiffusers/deploy/sd3/README.md
index 235871d1a..b2804832d 100644
--- a/ppdiffusers/deploy/sd3/README.md
+++ b/ppdiffusers/deploy/sd3/README.md
@@ -11,9 +11,15 @@
 python -c "import use_triton_in_paddle; use_triton_in_paddle.make_triton_compatible_with_paddle()"
 
 # Install the develop build of paddle; pick the wheel that matches your CUDA version (CUDA 12.3 here)
 python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/
 
+# Install the paddlemix package to use the custom operators integrated in it.
+python -m pip install paddlemix
+
 # Set the path of libCutlassGemmEpilogue.so
 # See https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/fusion/cutlass/gemm_epilogue/README.md for details
 export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/gemm_epilogue/build:$LD_LIBRARY_PATH
+- Note that this step uses the Cutlass fused operators to speed up static-graph inference, but it is optional.
+If you do not use Cutlass, set `exp_enable_use_cutlass` in `./text_to_image_generation-stable_diffusion_3.py` to False.
 ```
 
 High-performance inference command:
@@ -23,6 +29,8 @@ python text_to_image_generation-stable_diffusion_3.py --dtype float16 --height
 --num-inference-steps 50 --inference_optimize 1 \
 --benchmark 1
 ```
+Note: --inference_optimize 1 enables inference optimization, and --benchmark 1 enables benchmarking.
+
 
 - Performance measured on an NVIDIA A100-SXM4-40GB:
@@ -60,6 +68,8 @@ python -m paddle.distributed.launch --gpus "0,1,2,3" text_to_image_generation-st
 --dp_size 2 \
 --benchmark 1
 ```
+Note: --inference_optimize 1 enables inference optimization, and --benchmark 1 enables benchmarking.
+
 
 ## Performance measured on an NVIDIA A800-SXM4-80GB:
 | Paddle mp_size=2 & dp_size=2 | Paddle mp_size=2 | Paddle dp_size=2 | Paddle Single Card | Paddle dygraph |

diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
index d624bce09..b4819e482 100644
--- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
@@ -432,19 +432,19 @@ def custom_modify_weight(cls, state_dict):
         mp_degree = hcg.get_model_parallel_world_size()
         if mp_degree > 1:
             if i < 23:
-                tmpc = paddle.split(state_dict[f"simplified_sd3.to_add_out_linear.{i}.weight"], 2, axis=0)
-                state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.weight"] = tmpc[mp_id]
+                tmp = paddle.split(state_dict[f"simplified_sd3.to_add_out_linear.{i}.weight"], 2, axis=0)
+                state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.weight"] = tmp[mp_id]
                 state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.bias"] = state_dict[
                     f"simplified_sd3.to_add_out_linear.{i}.bias"
                 ]
-                tmpf = paddle.split(state_dict[f"simplified_sd3.ffn2_context.{i}.weight"], 2, axis=0)
-                state_dict[f"simplified_sd3.ffn2_context_mp.{i}.weight"] = tmpf[mp_id]
+                tmp = paddle.split(state_dict[f"simplified_sd3.ffn2_context.{i}.weight"], 2, axis=0)
+                state_dict[f"simplified_sd3.ffn2_context_mp.{i}.weight"] = tmp[mp_id]
                 state_dict[f"simplified_sd3.ffn2_context_mp.{i}.bias"] = state_dict[
                     f"simplified_sd3.ffn2_context.{i}.bias"
                 ]
                 for placeholder in ["weight", "bias"]:
-                    tmpf1 = paddle.split(state_dict[f"simplified_sd3.ffn1_context.{i}.{placeholder}"], 2, axis=-1)
-                    state_dict[f"simplified_sd3.ffn1_context_mp.{i}.{placeholder}"] = tmpf1[mp_id]
+                    tmp = paddle.split(state_dict[f"simplified_sd3.ffn1_context.{i}.{placeholder}"], 2, axis=-1)
+                    state_dict[f"simplified_sd3.ffn1_context_mp.{i}.{placeholder}"] = tmp[mp_id]
             for placeholder in ["weight", "bias"]:
-                tmp = paddle.split(state_dict[f"simplified_sd3.ffn1.{i}.{placeholder}"], 2, axis=-1)
+                tmp = paddle.split(state_dict[f"simplified_sd3.ffn1.{i}.{placeholder}"], 2, axis=-1)
                 state_dict[f"simplified_sd3.ffn1_mp.{i}.{placeholder}"] = tmp[mp_id]
@@ -452,14 +452,14 @@ def custom_modify_weight(cls, state_dict):
                 tmpq = paddle.split(
                     state_dict[f"simplified_sd3.{placeholder1}qkv.{i}.{placeholder}"], 6, axis=-1
                 )
-                tmpqp = [
+                tmp = [
                     paddle.concat([tmpq[0], tmpq[2], tmpq[4]], axis=-1),
                     paddle.concat([tmpq[1], tmpq[3], tmpq[5]], axis=-1),
                 ]
-                state_dict[f"simplified_sd3.{placeholder1}qkv_mp.{i}.{placeholder}"] = tmpqp[mp_id]
+                state_dict[f"simplified_sd3.{placeholder1}qkv_mp.{i}.{placeholder}"] = tmp[mp_id]
             for mp_name in ["ffn2", "to_out_linear"]:
-                tmpc = paddle.split(state_dict[f"simplified_sd3.{mp_name}.{i}.weight"], 2, axis=0)
-                state_dict[f"simplified_sd3.{mp_name}_mp.{i}.weight"] = tmpc[mp_id]
+                tmp = paddle.split(state_dict[f"simplified_sd3.{mp_name}.{i}.weight"], 2, axis=0)
+                state_dict[f"simplified_sd3.{mp_name}_mp.{i}.weight"] = tmp[mp_id]
                 state_dict[f"simplified_sd3.{mp_name}_mp.{i}.bias"] = state_dict[
                     f"simplified_sd3.{mp_name}.{i}.bias"
                 ]
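The fused-QKV handling in `custom_modify_weight` above is the subtle part of the weight split: Q, K and V sit side by side in one `[dim, 3*dim]` matrix, so a naive halving would hand rank 0 all of Q plus half of K. A toy-sized sketch of the interleaved 2-way split (dimensions are illustrative, not SD3's):

```python
# Toy check of the 2-way fused-QKV column split: chunks 0/2/4 are the first
# halves of Q, K, V and chunks 1/3/5 the second halves, so each rank's shard
# is itself a valid fused-QKV weight of half width.
import paddle

dim = 8  # illustrative only
w = paddle.arange(dim * 3 * dim, dtype="float32").reshape([dim, 3 * dim])

chunks = paddle.split(w, 6, axis=-1)
shard0 = paddle.concat([chunks[0], chunks[2], chunks[4]], axis=-1)  # rank 0
shard1 = paddle.concat([chunks[1], chunks[3], chunks[5]], axis=-1)  # rank 1

# Verify against plain slices of Q, K, V: rank 0 holds the first dim//2
# output columns of each of the three projections.
q, k, v = paddle.split(w, 3, axis=-1)
expected0 = paddle.concat([q[:, : dim // 2], k[:, : dim // 2], v[:, : dim // 2]], axis=-1)
assert paddle.allclose(shard0, expected0).item()
```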
From b6e08c0ebd9b286caa106b092d7a6f4d01342bfc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Thu, 14 Nov 2024 15:37:11 +0800
Subject: [PATCH 20/21] update sd3 parallel

---
 ..._to_image_generation-stable_diffusion_3.py | 10 ++++++++
 .../ppdiffusers/models/transformer_sd3.py     | 24 +++++++++----------
 .../pipeline_stable_diffusion_3.py            |  6 +++--
 3 files changed, 26 insertions(+), 14 deletions(-)

diff --git a/ppdiffusers/deploy/sd3/text_to_image_generation-stable_diffusion_3.py b/ppdiffusers/deploy/sd3/text_to_image_generation-stable_diffusion_3.py
index f848a893e..4157ae855 100644
--- a/ppdiffusers/deploy/sd3/text_to_image_generation-stable_diffusion_3.py
+++ b/ppdiffusers/deploy/sd3/text_to_image_generation-stable_diffusion_3.py
@@ -89,6 +89,8 @@ def parse_args():
     "stabilityai/stable-diffusion-3-medium-diffusers",
     paddle_dtype=inference_dtype,
 )
+
+
 pipe.transformer = paddle.incubate.jit.inference(
     pipe.transformer,
     save_model_dir="./tmp/1024_TP_sd3_parallel",
@@ -135,8 +137,16 @@ def parse_args():
         duringtime = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
         sumtime += duringtime
         print("SD3 end to end time : ", duringtime, "ms")
+        paddle.device.cuda.empty_cache()
+        inference_global_mem = paddle.device.cuda.memory_reserved() / (1024**3)
+        print(f"Inference used CUDA memory : {inference_global_mem:.3f} GiB")
 
     print("SD3 ave end to end time : ", sumtime / repeat_times, "ms")
+
+    paddle.device.cuda.empty_cache()
+    inference_global_mem = paddle.device.cuda.memory_reserved() / (1024**3)
+    print(f"Inference used CUDA memory : {inference_global_mem:.3f} GiB")
+
     cuda_mem_after_used = paddle.device.cuda.max_memory_allocated() / (1024**3)
     print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB")

diff --git a/ppdiffusers/ppdiffusers/models/transformer_sd3.py b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
index b4819e482..29531379b 100644
--- a/ppdiffusers/ppdiffusers/models/transformer_sd3.py
+++ b/ppdiffusers/ppdiffusers/models/transformer_sd3.py
@@ -432,33 +432,33 @@ def custom_modify_weight(cls, state_dict):
         mp_degree = hcg.get_model_parallel_world_size()
         if mp_degree > 1:
             if i < 23:
-                tmp = paddle.split(state_dict[f"simplified_sd3.to_add_out_linear.{i}.weight"], 2, axis=0)
+                tmp = paddle.split(state_dict[f"simplified_sd3.to_add_out_linear.{i}.weight"], mp_degree, axis=0)
                 state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.weight"] = tmp[mp_id]
                 state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.bias"] = state_dict[
                     f"simplified_sd3.to_add_out_linear.{i}.bias"
                 ]
-                tmp = paddle.split(state_dict[f"simplified_sd3.ffn2_context.{i}.weight"], 2, axis=0)
+                tmp = paddle.split(state_dict[f"simplified_sd3.ffn2_context.{i}.weight"], mp_degree, axis=0)
                 state_dict[f"simplified_sd3.ffn2_context_mp.{i}.weight"] = tmp[mp_id]
                 state_dict[f"simplified_sd3.ffn2_context_mp.{i}.bias"] = state_dict[
                     f"simplified_sd3.ffn2_context.{i}.bias"
                 ]
                 for placeholder in ["weight", "bias"]:
-                    tmp = paddle.split(state_dict[f"simplified_sd3.ffn1_context.{i}.{placeholder}"], 2, axis=-1)
+                    tmp = paddle.split(
+                        state_dict[f"simplified_sd3.ffn1_context.{i}.{placeholder}"], mp_degree, axis=-1
+                    )
                     state_dict[f"simplified_sd3.ffn1_context_mp.{i}.{placeholder}"] = tmp[mp_id]
             for placeholder in ["weight", "bias"]:
-                tmp = paddle.split(state_dict[f"simplified_sd3.ffn1.{i}.{placeholder}"], 2, axis=-1)
+                tmp = paddle.split(state_dict[f"simplified_sd3.ffn1.{i}.{placeholder}"], mp_degree, axis=-1)
                 state_dict[f"simplified_sd3.ffn1_mp.{i}.{placeholder}"] = tmp[mp_id]
                 for placeholder1 in ["", "e"]:
-                    tmpq = paddle.split(
-                        state_dict[f"simplified_sd3.{placeholder1}qkv.{i}.{placeholder}"], 6, axis=-1
+                    tmp = paddle.split(
+                        state_dict[f"simplified_sd3.{placeholder1}qkv.{i}.{placeholder}"], 3 * mp_degree, axis=-1
+                    )
+                    state_dict[f"simplified_sd3.{placeholder1}qkv_mp.{i}.{placeholder}"] = paddle.concat(
+                        [tmp[mp_id], tmp[1 * mp_degree + mp_id], tmp[2 * mp_degree + mp_id]], axis=-1
                     )
-                    tmp = [
-                        paddle.concat([tmpq[0], tmpq[2], tmpq[4]], axis=-1),
-                        paddle.concat([tmpq[1], tmpq[3], tmpq[5]], axis=-1),
-                    ]
-                    state_dict[f"simplified_sd3.{placeholder1}qkv_mp.{i}.{placeholder}"] = tmp[mp_id]
             for mp_name in ["ffn2", "to_out_linear"]:
-                tmp = paddle.split(state_dict[f"simplified_sd3.{mp_name}.{i}.weight"], 2, axis=0)
+                tmp = paddle.split(state_dict[f"simplified_sd3.{mp_name}.{i}.weight"], mp_degree, axis=0)
                 state_dict[f"simplified_sd3.{mp_name}_mp.{i}.weight"] = tmp[mp_id]
                 state_dict[f"simplified_sd3.{mp_name}_mp.{i}.bias"] = state_dict[
                     f"simplified_sd3.{mp_name}.{i}.bias"
                 ]

diff --git a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 74855b3e8..0ae9102a6 100644
--- a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -802,7 +802,9 @@ def __call__(
         hcg = fleet.get_hybrid_communicate_group()
         dp_degree = hcg.get_data_parallel_world_size()
-        if dp_degree > 1 and self.do_classifier_free_guidance:
+        enabled_cfg_dp = True if dp_degree > 1 and self.do_classifier_free_guidance else False
+
+        if enabled_cfg_dp:
             dp_id = hcg.get_data_parallel_rank()
             dp_group = hcg.get_data_parallel_group()
@@ -832,7 +834,7 @@ def __call__(
         else:
             output = model_output[0]
 
-        if dp_degree > 1 and self.do_classifier_free_guidance:
+        if enabled_cfg_dp:
             tmp_shape = output.shape
             tmp_shape[0] *= 2
             noise_pred = paddle.zeros(tmp_shape, dtype=output.dtype)

From 84154a9ddfb997a51572f96ae704ebe0d4332edd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Thu, 14 Nov 2024 16:38:10 +0800
Subject: [PATCH 21/21] update sd3

---
 .../pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 0ae9102a6..455822b73 100644
--- a/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -802,7 +802,7 @@ def __call__(
         hcg = fleet.get_hybrid_communicate_group()
         dp_degree = hcg.get_data_parallel_world_size()
-        enabled_cfg_dp = True if dp_degree > 1 and self.do_classifier_free_guidance else False
+        enabled_cfg_dp = dp_degree > 1 and self.do_classifier_free_guidance
 
         if enabled_cfg_dp:
             dp_id = hcg.get_data_parallel_rank()
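With the final `enabled_cfg_dp` form in place, the CFG data-parallel path reads as a short round trip. The sketch below isolates it with a stand-in for the transformer call; the shapes and the dummy model are assumptions, and it presumes `fleet.init` already ran with `dp_degree=2`:

```python
# Each DP rank takes one half of the [uncond, text] CFG batch, runs its half,
# and all_gather (in rank order, matching the split order) rebuilds the full
# noise prediction so the usual guidance math is unchanged.
import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet

hcg = fleet.get_hybrid_communicate_group()
dp_id = hcg.get_data_parallel_rank()
dp_group = hcg.get_data_parallel_group()

latent_model_input = paddle.concat([paddle.randn([1, 16, 64, 64])] * 2)  # CFG batch of 2
latent_input = paddle.split(latent_model_input, 2, axis=0)[dp_id]

output = latent_input * 0.5  # stand-in for self.transformer(...)

tmp_shape = output.shape
tmp_shape[0] *= 2
noise_pred = paddle.zeros(tmp_shape, dtype=output.dtype)
dist.all_gather(noise_pred, output, group=dp_group)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
```

Since `all_gather` orders results by rank, rank 0 must hold the unconditional half; indexing the `paddle.split` output with `dp_id` guarantees exactly that.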