Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

weight parallel #773

Open
wants to merge 23 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def parse_args():

if args.inference_optimize:
os.environ["INFERENCE_OPTIMIZE"] = "True"
os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True"
# os.environ["INFERENCE_OPTIMIZE_TRITON"] = "True"
if args.inference_optimize_bp:
os.environ["INFERENCE_OPTIMIZE_BP"] = "True"
if args.dtype == "float32":
Expand Down Expand Up @@ -76,6 +76,10 @@ def parse_args():
hcg = fleet.get_hybrid_communicate_group()
mp_id = hcg.get_model_parallel_rank()
rank_id = dist.get_rank()
if rank_id==0:
os.environ["TRITON_KERNEL_CACHE_DIR"]="./tmp/sd3_parallel/2_2"
elif rank_id==1:
os.environ["TRITON_KERNEL_CACHE_DIR"]="./tmp/sd3_parallel/2_3"

import datetime
from ppdiffusers import StableDiffusion3Pipeline
Expand All @@ -88,9 +92,9 @@ def parse_args():

pipe.transformer = paddle.incubate.jit.inference(
pipe.transformer,
save_model_dir="./tmp/sd3",
enable_new_ir=True,
cache_static_model=True,
save_model_dir="./tmp/TP_sd3_parallel",
enable_new_ir=False,
cache_static_model=False,
exp_enable_use_cutlass=True,
delete_pass_lists=["add_norm_fuse_pass"],
)
Expand Down
92 changes: 67 additions & 25 deletions ppdiffusers/ppdiffusers/models/simplified_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
import paddle.nn.functional as F
from paddle import nn

import paddle.distributed as dist
import paddle.distributed.fleet as fleet

model_parallel_size=2

class SimplifiedSD3(nn.Layer):
def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
Expand All @@ -31,14 +35,29 @@ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attentio

self.norm_last_context = nn.LayerNorm(self.dim, epsilon=1e-6, weight_attr=False, bias_attr=True)

self.qkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
self.eqkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
self.to_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
self.to_add_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers - 1)])
self.ffn1 = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers)])
self.ffn2 = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers)])
self.ffn1_context = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers - 1)])
self.ffn2_context = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers - 1)])


if model_parallel_size > 1:
self.qkv_mp = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
Copy link
Contributor

@zhoutianzi666 zhoutianzi666 Oct 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from fleet.meta_parallel import RowParallelLinear
from fleet.meta_parallel import ColumnParallelLinear

被频繁使用了,应该预先import as吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

收到,已修改~

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from fleet.meta_parallel import RowParallelLinear
from fleet.meta_parallel import ColumnParallelLinear

from fleet.meta_parallel import RowParallelLinear
from fleet.meta_parallel import ColumnParallelLinear
改成这样

self.eqkv_mp = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
self.to_out_linear_mp = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
self.to_add_out_linear_mp = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])

self.ffn1_mp = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
self.ffn2_mp = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
self.ffn1_context_mp = nn.LayerList([fleet.meta_parallel.ColumnParallelLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)])
self.ffn2_context_mp = nn.LayerList([fleet.meta_parallel.RowParallelLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)])
else:
self.qkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
self.eqkv = nn.LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
self.to_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
self.to_add_out_linear = nn.LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])

self.ffn1 = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers)])
self.ffn2 = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers)])
self.ffn1_context = nn.LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers - 1)])
self.ffn2_context = nn.LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers - 1)])


def forward(self, hidden_states, encoder_hidden_states, temb):
print("--------------------this is simplified_sd3------------------------")
Expand Down Expand Up @@ -103,37 +122,55 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
epsilon=1e-06,
)

qkv = self.qkv[i](norm_hidden_states)
eqkv = self.eqkv[i](norm_encoder_hidden_states)
if model_parallel_size > 1:
qkv = self.qkv_mp[i](norm_hidden_states)
eqkv = self.eqkv_mp[i](norm_encoder_hidden_states)

else:
qkv = self.qkv[i](norm_hidden_states)
eqkv = self.eqkv[i](norm_encoder_hidden_states)


q, k, v = paddlemix.triton_ops.split_concat(qkv, eqkv)
bs = hidden_states.shape[0]
q = q.reshape([bs, -1, 24, 64])
k = k.reshape([bs, -1, 24, 64])
v = v.reshape([bs, -1, 24, 64])
hs = q.shape[2]
if model_parallel_size > 1:
q = q.reshape([bs, -1, 12, hs//12])
k = k.reshape([bs, -1, 12, hs//12])
v = v.reshape([bs, -1, 12, hs//12])
else:
q = q.reshape([bs, -1, 24, 64])
k = k.reshape([bs, -1, 24, 64])
v = v.reshape([bs, -1, 24, 64])

norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False)
norm_hidden_states1 = norm_hidden_states1.reshape([bs, -1, self.dim])
norm_hidden_states1 = norm_hidden_states1.reshape([bs, -1, hs])
attn_output, context_attn_output = paddle.split(norm_hidden_states1, num_or_sections=[seq1, seq2], axis=1)

# attn_output, context_attn_output = paddlemix.triton_ops.triton_split(
# norm_hidden_states1, num_or_sections=[1024, 154], axis=1
# )

attn_output = paddle.nn.functional.linear(
attn_output, self.to_out_linear[i].weight, self.to_out_linear[i].bias
)

if not context_pre_only:
if model_parallel_size > 1:
attn_output = self.to_out_linear_mp[i](attn_output)
context_attn_output = self.to_add_out_linear_mp[i](context_attn_output)
else:
attn_output = self.to_out_linear[i](attn_output)
context_attn_output = self.to_add_out_linear[i](context_attn_output)

hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
hidden_states, attn_output, gate_msa, scale_mlp, shift_mlp, epsilon=1e-06
)

# ffn1
ffn_output = self.ffn1[i](norm_hidden_states)
ffn_output = F.gelu(ffn_output, approximate=True)
ffn_output = self.ffn2[i](ffn_output)
if model_parallel_size > 1:
ffn_output = self.ffn1_mp[i](norm_hidden_states)
ffn_output = F.gelu(ffn_output, approximate=True)
ffn_output = self.ffn2_mp[i](ffn_output)
else:
ffn_output = self.ffn1[i](norm_hidden_states)
ffn_output = F.gelu(ffn_output, approximate=True)
ffn_output = self.ffn2[i](ffn_output)

if context_pre_only:
ffn_output = gate_mlp.unsqueeze(1) * ffn_output
Expand All @@ -149,9 +186,14 @@ def forward(self, hidden_states, encoder_hidden_states, temb):
encoder_hidden_states, context_attn_output, c_gate_msa, c_scale_mlp, c_shift_mlp, epsilon=1e-06
)

context_ffn_output = self.ffn1_context[i](norm_encoder_hidden_states)
context_ffn_output = F.gelu(context_ffn_output, approximate=True)
context_ffn_output = self.ffn2_context[i](context_ffn_output)
if model_parallel_size > 1:
context_ffn_output = self.ffn1_context_mp[i](norm_encoder_hidden_states)
context_ffn_output = F.gelu(context_ffn_output, approximate=True)
context_ffn_output = self.ffn2_context_mp[i](context_ffn_output)
else:
context_ffn_output = self.ffn1_context[i](norm_encoder_hidden_states)
context_ffn_output = F.gelu(context_ffn_output, approximate=True)
context_ffn_output = self.ffn2_context[i](context_ffn_output)

last_context_ffn_output = context_ffn_output
last_context_hidden_states = encoder_hidden_states
Expand Down
28 changes: 28 additions & 0 deletions ppdiffusers/ppdiffusers/models/transformer_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,31 @@ def custom_modify_weight(cls, state_dict):
],
axis=-1,
)
from paddle.distributed.fleet.utils import recompute
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
rank_id = dist.get_rank()
for mp_name in ["ffn2","to_out_linear"]:
tmpc = paddle.split(state_dict[f"simplified_sd3.{mp_name}.{i}.weight"],2,axis=0)
state_dict[f"simplified_sd3.{mp_name}_mp.{i}.weight"] = tmpc[rank_id]
state_dict[f"simplified_sd3.{mp_name}_mp.{i}.bias"] = state_dict[f"simplified_sd3.{mp_name}.{i}.bias"]
if i<23:
tmpc = paddle.split(state_dict[f"simplified_sd3.to_add_out_linear.{i}.weight"],2,axis=0)
state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.weight"] = tmpc[rank_id]
state_dict[f"simplified_sd3.to_add_out_linear_mp.{i}.bias"] = state_dict[f"simplified_sd3.to_add_out_linear.{i}.bias"]
tmpf = paddle.split(state_dict[f"simplified_sd3.ffn2_context.{i}.weight"],2,axis=0)
state_dict[f"simplified_sd3.ffn2_context_mp.{i}.weight"] = tmpf[rank_id]
state_dict[f"simplified_sd3.ffn2_context_mp.{i}.bias"] = state_dict[f"simplified_sd3.ffn2_context.{i}.bias"]
for placeholder in ["weight", "bias"]:
tmpf1 = paddle.split(state_dict[f"simplified_sd3.ffn1_context.{i}.{placeholder}"],2,axis=-1)
state_dict[f"simplified_sd3.ffn1_context_mp.{i}.{placeholder}"] = tmpf1[rank_id]
for placeholder in ["weight", "bias"]:
tmp = paddle.split(state_dict[f"simplified_sd3.ffn1.{i}.{placeholder}"],2,axis=-1)
state_dict[f"simplified_sd3.ffn1_mp.{i}.{placeholder}"] = tmp[rank_id]
for placeholder1 in ["", "e"]:
tmpq = paddle.split(state_dict[f"simplified_sd3.{placeholder1}qkv.{i}.{placeholder}"],6,axis=-1)
tmpqp = [
paddle.concat([tmpq[0], tmpq[2],tmpq[4]],axis=-1),
paddle.concat([tmpq[1], tmpq[3],tmpq[5]],axis=-1)
]
state_dict[f"simplified_sd3.{placeholder1}qkv_mp.{i}.{placeholder}"] = tmpqp[rank_id]
Original file line number Diff line number Diff line change
Expand Up @@ -800,22 +800,22 @@ def __call__(
latent_model_input = paddle.concat([latents] * 2) if self.do_classifier_free_guidance else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
if self.inference_optimize_bp and self.do_classifier_free_guidance:
latent_input ,latent_model_input_ = paddle.split(latent_model_input,2,axis=0)
timestep_input ,timestep_ = paddle.split(timestep,2,axis=0)
prompt_embeds_input ,prompt_embeds_ = paddle.split(prompt_embeds,2,axis=0)
pooled_prompt_embeds_input ,pooled_prompt_embeds_ = paddle.split(pooled_prompt_embeds,2,axis=0)
# if self.inference_optimize_bp and self.do_classifier_free_guidance:
# latent_input ,latent_model_input_ = paddle.split(latent_model_input,2,axis=0)
# timestep_input ,timestep_ = paddle.split(timestep,2,axis=0)
# prompt_embeds_input ,prompt_embeds_ = paddle.split(prompt_embeds,2,axis=0)
# pooled_prompt_embeds_input ,pooled_prompt_embeds_ = paddle.split(pooled_prompt_embeds,2,axis=0)

dist.scatter(latent_input,[latent_input,latent_model_input_])
dist.scatter(timestep_input,[timestep_input,timestep_])
dist.scatter(prompt_embeds_input,[prompt_embeds_input,prompt_embeds_])
dist.scatter(pooled_prompt_embeds_input,[pooled_prompt_embeds_input,pooled_prompt_embeds_])

else:
latent_input = latent_model_input
timestep_input = timestep
prompt_embeds_input = prompt_embeds
pooled_prompt_embeds_input = pooled_prompt_embeds
# dist.scatter(latent_input,[latent_input,latent_model_input_])
# dist.scatter(timestep_input,[timestep_input,timestep_])
# dist.scatter(prompt_embeds_input,[prompt_embeds_input,prompt_embeds_])
# dist.scatter(pooled_prompt_embeds_input,[pooled_prompt_embeds_input,pooled_prompt_embeds_])

# else:
latent_input = latent_model_input
timestep_input = timestep
prompt_embeds_input = prompt_embeds
pooled_prompt_embeds_input = pooled_prompt_embeds

model_output = self.transformer(
hidden_states=latent_input,
Expand All @@ -832,13 +832,13 @@ def __call__(
else:
output = model_output[0]

if self.inference_optimize_bp:
tmp_shape = output.shape
tmp_shape[0] *=2
noise_pred = paddle.zeros(tmp_shape,dtype=output.dtype)
dist.all_gather(noise_pred,output)
else:
noise_pred = output
# if self.inference_optimize_bp:
# tmp_shape = output.shape
# tmp_shape[0] *=2
# noise_pred = paddle.zeros(tmp_shape,dtype=output.dtype)
# dist.all_gather(noise_pred,output)
# else:
noise_pred = output



Expand Down