From d232e132f6af552c351bb046a38df4bce009c8aa Mon Sep 17 00:00:00 2001 From: Czxck001 <10724409+Czxck001@users.noreply.github.com> Date: Tue, 29 Oct 2024 22:19:07 -0700 Subject: [PATCH] Support sd3.5 medium and MMDiT-X (#2587) * extract attn out of joint_attn * further adjust attn and joint_attn * add mmdit-x support * support sd3.5-medium in the example * update README.md --- .../examples/stable-diffusion-3/README.md | 20 +- .../examples/stable-diffusion-3/main.rs | 44 +++- .../src/models/mmdit/blocks.rs | 191 ++++++++++++++++-- candle-transformers/src/models/mmdit/model.rs | 49 ++++- 4 files changed, 269 insertions(+), 35 deletions(-) diff --git a/candle-examples/examples/stable-diffusion-3/README.md b/candle-examples/examples/stable-diffusion-3/README.md index 52ebfa55e1..adae1b566e 100644 --- a/candle-examples/examples/stable-diffusion-3/README.md +++ b/candle-examples/examples/stable-diffusion-3/README.md @@ -1,8 +1,8 @@ -# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3 Medium +# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3/3.5 ![](assets/stable-diffusion-3.jpg) -*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k* +*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k*, generated by Stable Diffusion 3 Medium Stable Diffusion 3 Medium is a text-to-image model based on Multimodal Diffusion Transformer (MMDiT) architecture. @@ -10,9 +10,17 @@ Stable Diffusion 3 Medium is a text-to-image model based on Multimodal Diffusion - [research paper](https://arxiv.org/pdf/2403.03206) - [announcement blog post](https://stability.ai/news/stable-diffusion-3-medium) +Stable Diffusion 3.5 is a family of text-to-image models with latest improvements: +- [announcement blog post](https://stability.ai/news/introducing-stable-diffusion-3-5) + +It has three variants: +- [Stable Diffusion 3.5 Large](https://huggingface.co/stabilityai/stable-diffusion-3.5-large) @ 8.1b params, with scaled and slightly modified MMDiT architecture. +- [Stable Diffusion 3.5 Large Turbo](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo) distilled version that enables 4-step inference. +- [Stable Diffusion 3.5 Medium](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) @ 2.5b params, with improved MMDiT-X architecture. + ## Getting access to the weights -The weights of Stable Diffusion 3 Medium is released by Stability AI under the Stability Community License. You will need to accept the conditions and acquire a license by visiting [the repo on HuggingFace Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium) to gain access to the weights for your HuggingFace account. +The weights of Stable Diffusion 3/3.5 is released by Stability AI under the Stability Community License. You will need to accept the conditions and acquire a license by visiting the repos on HuggingFace Hub to gain access to the weights for your HuggingFace account. To allow your computer to gain access to the public-gated repos on HuggingFace, you might need to create a [HuggingFace User Access Tokens](https://huggingface.co/docs/hub/en/security-tokens) (recommended) and log in on your computer if you haven't done that before. A convenient way to do the login is to use [huggingface-cli](https://huggingface.co/docs/huggingface_hub/en/guides/cli): @@ -27,10 +35,12 @@ On the first run, the weights will be automatically downloaded from the Huggingf ```shell cargo run --example stable-diffusion-3 --release --features=cuda -- \ - --height 1024 --width 1024 \ + --which 3-medium --height 1024 --width 1024 \ --prompt 'A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k' ``` +To use different models, changed the value of `--which` option. (Possible values: `3-medium`, `3.5-large`, `3.5-large-turbo` and `3.5-medium`). + To display other options available, ```shell @@ -45,7 +55,7 @@ cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- - ## Performance Benchmark -Below benchmark is done by generating 1024-by-1024 image from 28 steps of Euler sampling and measure the average speed (iteration per seconds). +Below benchmark is done with Stable Diffusion 3 Medium by generating 1024-by-1024 image from 28 steps of Euler sampling and measure the average speed (iteration per seconds). [candle](https://github.com/huggingface/candle) and [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) is based on the commit of [0d96ec3](https://github.com/huggingface/candle/commit/0d96ec31e8be03f844ed0aed636d6217dee9c7bc). diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs index d0bf4bb803..31d3fc4234 100644 --- a/candle-examples/examples/stable-diffusion-3/main.rs +++ b/candle-examples/examples/stable-diffusion-3/main.rs @@ -19,13 +19,15 @@ enum Which { V3_5Large, #[value(name = "3.5-large-turbo")] V3_5LargeTurbo, + #[value(name = "3.5-medium")] + V3_5Medium, } impl Which { fn is_3_5(&self) -> bool { match self { Self::V3Medium => false, - Self::V3_5Large | Self::V3_5LargeTurbo => true, + Self::V3_5Large | Self::V3_5LargeTurbo | Self::V3_5Medium => true, } } } @@ -117,36 +119,59 @@ fn main() -> Result<()> { let default_inference_steps = match which { Which::V3_5Large => 28, Which::V3_5LargeTurbo => 4, + Which::V3_5Medium => 28, Which::V3Medium => 28, }; let num_inference_steps = num_inference_steps.unwrap_or(default_inference_steps); let default_cfg_scale = match which { Which::V3_5Large => 4.0, Which::V3_5LargeTurbo => 1.0, + Which::V3_5Medium => 4.0, Which::V3Medium => 4.0, }; let cfg_scale = cfg_scale.unwrap_or(default_cfg_scale); let api = hf_hub::api::sync::Api::new()?; let (mmdit_config, mut triple, vb) = if which.is_3_5() { - let sai_repo = { + let sai_repo_for_text_encoders = { + let name = match which { + Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large", + Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo", + + // Unfortunately, stabilityai/stable-diffusion-3.5-medium doesn't have the monolithic text encoders that's usually + // placed under the text_encoders directory, like the case in stabilityai/stable-diffusion-3.5-large and -large-turbo. + // To make things worse, it currently only has partitioned model.fp16-00001-of-00002.safetensors and model.fp16-00002-of-00002.safetensors + // under the text_encoder_3 directory, for the t5xxl_fp16.safetensors model. This means that we need to merge the two partitions + // to get the monolithic text encoders. This is not a trivial task. + // Since the situation can change, we do not want to spend efforts to handle the uniqueness of stabilityai/stable-diffusion-3.5-medium, + // which involves different paths and merging the two partitions files for t5xxl_fp16.safetensors. + // so for now, we'll use the text encoder models from the stabilityai/stable-diffusion-3.5-large repository. + // TODO: Change to "stabilityai/stable-diffusion-3.5-medium" once the maintainers of the repository add back the monolithic text encoders. + Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-large", + Which::V3Medium => unreachable!(), + }; + api.repo(hf_hub::Repo::model(name.to_string())) + }; + let sai_repo_for_mmdit = { let name = match which { Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large", Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo", + Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-medium", Which::V3Medium => unreachable!(), }; api.repo(hf_hub::Repo::model(name.to_string())) }; - let clip_g_file = sai_repo.get("text_encoders/clip_g.safetensors")?; - let clip_l_file = sai_repo.get("text_encoders/clip_l.safetensors")?; - let t5xxl_file = sai_repo.get("text_encoders/t5xxl_fp16.safetensors")?; + let clip_g_file = sai_repo_for_text_encoders.get("text_encoders/clip_g.safetensors")?; + let clip_l_file = sai_repo_for_text_encoders.get("text_encoders/clip_l.safetensors")?; + let t5xxl_file = sai_repo_for_text_encoders.get("text_encoders/t5xxl_fp16.safetensors")?; let model_file = { let model_file = match which { Which::V3_5Large => "sd3.5_large.safetensors", Which::V3_5LargeTurbo => "sd3.5_large_turbo.safetensors", + Which::V3_5Medium => "sd3.5_medium.safetensors", Which::V3Medium => unreachable!(), }; - sai_repo.get(model_file)? + sai_repo_for_mmdit.get(model_file)? }; let triple = StableDiffusion3TripleClipWithTokenizer::new_split( &clip_g_file, @@ -157,7 +182,12 @@ fn main() -> Result<()> { let vb = unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F16, &device)? }; - (MMDiTConfig::sd3_5_large(), triple, vb) + match which { + Which::V3_5Large => (MMDiTConfig::sd3_5_large(), triple, vb), + Which::V3_5LargeTurbo => (MMDiTConfig::sd3_5_large(), triple, vb), + Which::V3_5Medium => (MMDiTConfig::sd3_5_medium(), triple, vb), + Which::V3Medium => unreachable!(), + } } else { let sai_repo = { let name = "stabilityai/stable-diffusion-3-medium"; diff --git a/candle-transformers/src/models/mmdit/blocks.rs b/candle-transformers/src/models/mmdit/blocks.rs index a1777f915b..912e249835 100644 --- a/candle-transformers/src/models/mmdit/blocks.rs +++ b/candle-transformers/src/models/mmdit/blocks.rs @@ -36,7 +36,6 @@ impl Module for LayerNormNoAffine { impl DiTBlock { pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { - // {'hidden_size': 1536, 'num_heads': 24} let norm1 = LayerNormNoAffine::new(1e-6); let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?; let norm2 = LayerNormNoAffine::new(1e-6); @@ -103,6 +102,117 @@ impl DiTBlock { } } +pub struct SelfAttnModulateIntermediates { + gate_msa: Tensor, + shift_mlp: Tensor, + scale_mlp: Tensor, + gate_mlp: Tensor, + gate_msa2: Tensor, +} + +pub struct SelfAttnDiTBlock { + norm1: LayerNormNoAffine, + attn: AttnProjections, + attn2: AttnProjections, + norm2: LayerNormNoAffine, + mlp: Mlp, + ada_ln_modulation: nn::Sequential, +} + +impl SelfAttnDiTBlock { + pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + let norm1 = LayerNormNoAffine::new(1e-6); + let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?; + let attn2 = AttnProjections::new(hidden_size, num_heads, vb.pp("attn2"))?; + let norm2 = LayerNormNoAffine::new(1e-6); + let mlp_ratio = 4; + let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?; + let n_mods = 9; + let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear( + hidden_size, + n_mods * hidden_size, + vb.pp("adaLN_modulation.1"), + )?); + + Ok(Self { + norm1, + attn, + attn2, + norm2, + mlp, + ada_ln_modulation, + }) + } + + pub fn pre_attention( + &self, + x: &Tensor, + c: &Tensor, + ) -> Result<(Qkv, Qkv, SelfAttnModulateIntermediates)> { + let modulation = self.ada_ln_modulation.forward(c)?; + let chunks = modulation.chunk(9, D::Minus1)?; + let ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + shift_msa2, + scale_msa2, + gate_msa2, + ) = ( + chunks[0].clone(), + chunks[1].clone(), + chunks[2].clone(), + chunks[3].clone(), + chunks[4].clone(), + chunks[5].clone(), + chunks[6].clone(), + chunks[7].clone(), + chunks[8].clone(), + ); + + let norm_x = self.norm1.forward(x)?; + let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?; + let qkv = self.attn.pre_attention(&modulated_x)?; + + let modulated_x2 = modulate(&norm_x, &shift_msa2, &scale_msa2)?; + let qkv2 = self.attn2.pre_attention(&modulated_x2)?; + + Ok(( + qkv, + qkv2, + SelfAttnModulateIntermediates { + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + gate_msa2, + }, + )) + } + + pub fn post_attention( + &self, + attn: &Tensor, + attn2: &Tensor, + x: &Tensor, + mod_interm: &SelfAttnModulateIntermediates, + ) -> Result { + let attn_out = self.attn.post_attention(attn)?; + let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?; + let attn_out2 = self.attn2.post_attention(attn2)?; + let x = x.add(&attn_out2.broadcast_mul(&mod_interm.gate_msa2.unsqueeze(1)?)?)?; + + let norm_x = self.norm2.forward(&x)?; + let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?; + let mlp_out = self.mlp.forward(&modulated_x)?; + let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?; + Ok(x) + } +} + pub struct QkvOnlyDiTBlock { norm1: LayerNormNoAffine, attn: QkvOnlyAttnProjections, @@ -190,14 +300,18 @@ fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result { shift.broadcast_add(&x.broadcast_mul(&scale_plus_one)?) } -pub struct JointBlock { +pub trait JointBlock { + fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)>; +} + +pub struct MMDiTJointBlock { x_block: DiTBlock, context_block: DiTBlock, num_heads: usize, use_flash_attn: bool, } -impl JointBlock { +impl MMDiTJointBlock { pub fn new( hidden_size: usize, num_heads: usize, @@ -214,8 +328,10 @@ impl JointBlock { use_flash_attn, }) } +} - pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> { +impl JointBlock for MMDiTJointBlock { + fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> { let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?; let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?; let (context_attn, x_attn) = @@ -228,6 +344,49 @@ impl JointBlock { } } +pub struct MMDiTXJointBlock { + x_block: SelfAttnDiTBlock, + context_block: DiTBlock, + num_heads: usize, + use_flash_attn: bool, +} + +impl MMDiTXJointBlock { + pub fn new( + hidden_size: usize, + num_heads: usize, + use_flash_attn: bool, + vb: nn::VarBuilder, + ) -> Result { + let x_block = SelfAttnDiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?; + let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?; + + Ok(Self { + x_block, + context_block, + num_heads, + use_flash_attn, + }) + } +} + +impl JointBlock for MMDiTXJointBlock { + fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> { + let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?; + let (x_qkv, x_qkv2, x_interm) = self.x_block.pre_attention(x, c)?; + let (context_attn, x_attn) = + joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?; + let x_attn2 = attn(&x_qkv2, self.num_heads, self.use_flash_attn)?; + let context_out = + self.context_block + .post_attention(&context_attn, context, &context_interm)?; + let x_out = self + .x_block + .post_attention(&x_attn, &x_attn2, x, &x_interm)?; + Ok((context_out, x_out)) + } +} + pub struct ContextQkvOnlyJointBlock { x_block: DiTBlock, context_block: QkvOnlyDiTBlock, @@ -309,26 +468,30 @@ fn joint_attn( v: Tensor::cat(&[&context_qkv.v, &x_qkv.v], 1)?, }; - let (batch_size, seqlen, _) = qkv.q.dims3()?; + let seqlen = qkv.q.dim(1)?; + let attn = attn(&qkv, num_heads, use_flash_attn)?; + let context_qkv_seqlen = context_qkv.q.dim(1)?; + let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?; + let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?; + + Ok((context_attn, x_attn)) +} + +fn attn(qkv: &Qkv, num_heads: usize, use_flash_attn: bool) -> Result { + let batch_size = qkv.q.dim(0)?; + let seqlen = qkv.q.dim(1)?; let qkv = Qkv { q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?, k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?, - v: qkv.v, + v: qkv.v.clone(), }; let headdim = qkv.q.dim(D::Minus1)?; let softmax_scale = 1.0 / (headdim as f64).sqrt(); - let attn = if use_flash_attn { flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)? } else { flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)? }; - - let attn = attn.reshape((batch_size, seqlen, ()))?; - let context_qkv_seqlen = context_qkv.q.dim(1)?; - let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?; - let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?; - - Ok((context_attn, x_attn)) + attn.reshape((batch_size, seqlen, ())) } diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs index 5b5c90b0c3..c7b4deedb2 100644 --- a/candle-transformers/src/models/mmdit/model.rs +++ b/candle-transformers/src/models/mmdit/model.rs @@ -1,10 +1,15 @@ -// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206). +// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206), +// as well as the MMDiT-X variant introduced for Stable Diffusion 3.5-medium (https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) // This follows the implementation of the MMDiT model in the ComfyUI repository. // https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L1 +// with MMDiT-X support following the Stability-AI/sd3.5 repository. +// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/mmditx.py#L1 use candle::{Module, Result, Tensor, D}; use candle_nn as nn; -use super::blocks::{ContextQkvOnlyJointBlock, FinalLayer, JointBlock}; +use super::blocks::{ + ContextQkvOnlyJointBlock, FinalLayer, JointBlock, MMDiTJointBlock, MMDiTXJointBlock, +}; use super::embedding::{ PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder, }; @@ -37,6 +42,20 @@ impl Config { } } + pub fn sd3_5_medium() -> Self { + Self { + patch_size: 2, + in_channels: 16, + out_channels: 16, + depth: 24, + head_size: 64, + adm_in_channels: 2048, + pos_embed_max_size: 384, + context_embed_size: 4096, + frequency_embedding_size: 256, + } + } + pub fn sd3_5_large() -> Self { Self { patch_size: 2, @@ -138,7 +157,7 @@ impl MMDiT { } pub struct MMDiTCore { - joint_blocks: Vec, + joint_blocks: Vec>, context_qkv_only_joint_block: ContextQkvOnlyJointBlock, final_layer: FinalLayer, } @@ -155,12 +174,24 @@ impl MMDiTCore { ) -> Result { let mut joint_blocks = Vec::with_capacity(depth - 1); for i in 0..depth - 1 { - joint_blocks.push(JointBlock::new( - hidden_size, - num_heads, - use_flash_attn, - vb.pp(format!("joint_blocks.{}", i)), - )?); + let joint_block_vb_pp = format!("joint_blocks.{}", i); + let joint_block: Box = + if vb.contains_tensor(&format!("{}.x_block.attn2.qkv.weight", joint_block_vb_pp)) { + Box::new(MMDiTXJointBlock::new( + hidden_size, + num_heads, + use_flash_attn, + vb.pp(&joint_block_vb_pp), + )?) + } else { + Box::new(MMDiTJointBlock::new( + hidden_size, + num_heads, + use_flash_attn, + vb.pp(&joint_block_vb_pp), + )?) + }; + joint_blocks.push(joint_block); } Ok(Self {