From 7513a5e005bfa7e205345aaeeb6f660cf178a598 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 28 Jul 2023 18:31:28 +0100
Subject: [PATCH] Line-up the llama implementation with the python-transformers
 one. (#271)

* Line-up the llama implementation with the python-transformers one.

* Also lineup the multiprocess version.
---
 candle-examples/examples/llama/model.rs        | 71 ++++++++-----------
 .../examples/llama_multiprocess/model.rs       |  2 +-
 2 files changed, 29 insertions(+), 44 deletions(-)

diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index c4d33f0b8..efb9aeef4 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -14,6 +14,7 @@ pub struct Config {
     pub n_embd: usize,
     pub n_key_value_head: usize,
     pub use_flash_attn: bool,
+    pub rms_norm_eps: f64,
 }
 
 impl Config {
@@ -27,6 +28,7 @@ impl Config {
             n_embd: 4096,
             n_key_value_head: 32,
             use_flash_attn,
+            rms_norm_eps: 1e-5,
         }
     }
 }
@@ -102,16 +104,13 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
 
 struct RmsNorm {
     scale: Tensor,
+    eps: f64,
 }
 
 impl RmsNorm {
-    fn load(size: usize, vb: VarBuilder) -> Result<Self> {
+    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
         let scale = vb.get(size, "weight")?;
-        Ok(Self::new(scale))
-    }
-
-    fn new(scale: Tensor) -> Self {
-        Self { scale }
+        Ok(Self { scale, eps })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -121,7 +120,7 @@ impl RmsNorm {
         let (b_sz, seq_len, hidden_size) = x.dims3()?;
         let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
+        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
         let size = self.scale.dims1()?;
         let scale = self
             .scale
@@ -292,14 +291,6 @@ struct Mlp {
 }
 
 impl Mlp {
-    fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
-        Self {
-            c_fc1,
-            c_fc2,
-            c_proj,
-        }
-    }
-
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
         self.c_proj.forward(&x)
@@ -311,7 +302,11 @@ impl Mlp {
         let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
         let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
         let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
-        Ok(Self::new(c_fc1, c_fc2, c_proj))
+        Ok(Self {
+            c_fc1,
+            c_fc2,
+            c_proj,
+        })
     }
 }
 
@@ -323,15 +318,6 @@ struct Block {
 }
 
 impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
-        Self {
-            rms_1,
-            attn,
-            rms_2,
-            mlp,
-        }
-    }
-
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
         let residual = x;
         let x = self.rms_1.forward(x)?;
@@ -344,15 +330,18 @@ impl Block {
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-        let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
-        let post_attention_layernorm =
-            RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
-        Ok(Self::new(
-            input_layernorm,
+        let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+        let rms_2 = RmsNorm::load(
+            cfg.hidden_size,
+            cfg.rms_norm_eps,
+            vb.pp("post_attention_layernorm"),
+        )?;
+        Ok(Self {
+            rms_1,
             attn,
-            post_attention_layernorm,
+            rms_2,
             mlp,
-        ))
+        })
     }
 }
 
@@ -364,15 +353,6 @@ pub struct Llama {
 }
 
 impl Llama {
-    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
-        Self {
-            wte,
-            blocks,
-            ln_f,
-            lm_head,
-        }
-    }
-
     pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (_b_sz, seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
@@ -388,11 +368,16 @@ impl Llama {
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
-        let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
+        let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layer)
             .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
             .collect();
-        Ok(Self::new(wte, blocks, norm, lm_head))
+        Ok(Self {
+            wte,
+            blocks,
+            ln_f,
+            lm_head,
+        })
     }
 }
 
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index ae2ef3e7a..6df7c1b0b 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -225,7 +225,7 @@ impl RmsNorm {
         let (b_sz, seq_len, hidden_size) = x.shape().dims3()?;
         let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
+        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
         let size = self.scale.shape().dims1()?;
         let scale = self
             .scale
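
Note (not part of the patch): the change above threads a configurable rms_norm_eps through RmsNorm instead of the hard-coded 1e-6, so the example can use the 1e-5 value that the python-transformers LLaMA config specifies; the epsilon only guards the square root against division by zero, but it must match the reference implementation for the outputs to line up exactly. The following is a minimal standalone sketch of that computation over a plain Vec<f32>, with illustrative names and no candle dependency; it is not the candle API, just the same arithmetic the patched forward() performs per hidden vector.

// Standalone sketch: RMSNorm with a configurable epsilon, mirroring the
// x / sqrt(mean(x^2) + eps) * scale computation from the patched forward().
struct RmsNorm {
    scale: Vec<f32>, // learned per-channel weight
    eps: f64,        // e.g. 1e-5, the value rms_norm_eps now carries
}

impl RmsNorm {
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        let hidden_size = x.len() as f64;
        // mean of squares over the hidden dimension
        let mean_sq: f64 =
            x.iter().map(|&v| (v as f64) * (v as f64)).sum::<f64>() / hidden_size;
        // 1 / sqrt(mean_sq + eps); eps keeps this finite for all-zero inputs
        let inv = 1.0 / (mean_sq + self.eps).sqrt();
        // normalize, then apply the learned scale element-wise
        x.iter()
            .zip(self.scale.iter())
            .map(|(&v, &s)| ((v as f64 * inv) as f32) * s)
            .collect()
    }
}

fn main() {
    let norm = RmsNorm { scale: vec![1.0; 4], eps: 1e-5 };
    println!("{:?}", norm.forward(&[1.0, 2.0, 3.0, 4.0]));
}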