From 2be9bd211e34333b605695242896903231ab26da Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 4 Aug 2024 18:52:40 +0100
Subject: [PATCH] Support for mistral-nemo. (#2396)

---
 candle-examples/examples/mistral/main.rs  | 21 ++++++++++++++-------
 candle-transformers/src/models/mistral.rs | 17 ++++++++++++-----
 2 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index 39cf61422d..66265488a0 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -149,6 +149,10 @@ enum Which {
     Mistral7bInstructV02,
     #[value(name = "7b-maths-v0.1")]
     Mathstral7bV01,
+    #[value(name = "nemo-2407")]
+    MistralNemo2407,
+    #[value(name = "nemo-instruct-2407")]
+    MistralNemoInstruct2407,
 }
 
 #[derive(Parser, Debug)]
@@ -263,13 +267,16 @@ fn main() -> Result<()> {
                 }
                 "lmz/candle-mistral".to_string()
             } else {
-                match args.which {
-                    Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(),
-                    Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(),
-                    Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(),
-                    Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(),
-                    Which::Mathstral7bV01 => "mistralai/mathstral-7B-v0.1".to_string(),
-                }
+                let name = match args.which {
+                    Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1",
+                    Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2",
+                    Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1",
+                    Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2",
+                    Which::Mathstral7bV01 => "mistralai/mathstral-7B-v0.1",
+                    Which::MistralNemo2407 => "mistralai/Mistral-Nemo-Base-2407",
+                    Which::MistralNemoInstruct2407 => "mistralai/Mistral-Nemo-Instruct-2407",
+                };
+                name.to_string()
             }
         }
     };
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index 1cb55f9e61..7e3b21c92f 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -15,6 +15,7 @@ pub struct Config {
     pub intermediate_size: usize,
     pub num_hidden_layers: usize,
     pub num_attention_heads: usize,
+    pub head_dim: Option<usize>,
     pub num_key_value_heads: usize,
     pub hidden_act: Activation,
     pub max_position_embeddings: usize,
@@ -34,6 +35,7 @@ impl Config {
             intermediate_size: 14336,
             num_hidden_layers: 32,
             num_attention_heads: 32,
+            head_dim: None,
             num_key_value_heads: 8,
             hidden_act: Activation::Silu,
             max_position_embeddings: 32768,
@@ -53,6 +55,7 @@ impl Config {
             intermediate_size: 14336,
             num_hidden_layers: 32,
             num_attention_heads: 32,
+            head_dim: None,
             num_key_value_heads: 8,
             hidden_act: Activation::Silu,
             max_position_embeddings: 32768,
@@ -71,6 +74,7 @@
             intermediate_size: 14336,
             num_hidden_layers: 32,
             num_attention_heads: 32,
+            head_dim: None,
             num_key_value_heads: 8,
             hidden_act: Activation::Silu,
             max_position_embeddings: 32768,
@@ -80,6 +84,11 @@
             use_flash_attn,
         }
     }
+
+    fn head_dim(&self) -> usize {
+        self.head_dim
+            .unwrap_or(self.hidden_size / self.num_attention_heads)
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -91,7 +100,7 @@ struct RotaryEmbedding {
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         let rope_theta = cfg.rope_theta as f32;
-        let dim = cfg.hidden_size / cfg.num_attention_heads;
+        let dim = cfg.head_dim();
         let max_seq_len = cfg.max_position_embeddings;
         let inv_freq: Vec<_> = (0..dim)
             .step_by(2)
@@ -183,7 +192,6 @@ struct Attention {
     num_kv_heads: usize,
     num_kv_groups: usize,
     head_dim: usize,
-    hidden_size: usize,
     rotary_emb: Arc<RotaryEmbedding>,
     kv_cache: Option<(Tensor, Tensor)>,
     use_flash_attn: bool,
@@ -195,7 +203,7 @@ impl Attention {
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
         let num_kv_groups = num_heads / num_kv_heads;
-        let head_dim = hidden_sz / num_heads;
+        let head_dim = cfg.head_dim();
         let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
         let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
         let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
@@ -209,7 +217,6 @@ impl Attention {
             num_kv_heads,
             num_kv_groups,
             head_dim,
-            hidden_size: hidden_sz,
             rotary_emb,
             kv_cache: None,
             use_flash_attn: cfg.use_flash_attn,
@@ -277,7 +284,7 @@ impl Attention {
         };
         attn_output
             .transpose(1, 2)?
-            .reshape((b_sz, q_len, self.hidden_size))?
+            .reshape((b_sz, q_len, self.num_heads * self.head_dim))?
             .apply(&self.o_proj)
     }
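
Why the optional head_dim: the Mistral-Nemo checkpoints specify the per-head dimension explicitly in their config, and it does not necessarily equal hidden_size / num_attention_heads, so keeping the old derivation would build projection and reshape sizes that do not match the weights. The standalone sketch below (not part of the patch; the Nemo numbers of hidden_size 5120, 32 attention heads and head_dim 128 are assumed for illustration) shows the fallback behaviour of the new head_dim helper:

// Sketch of the head_dim fallback added above: prefer the explicit value from the
// checkpoint config, otherwise derive it from hidden_size / num_attention_heads.
struct Cfg {
    hidden_size: usize,
    num_attention_heads: usize,
    head_dim: Option<usize>,
}

impl Cfg {
    fn head_dim(&self) -> usize {
        self.head_dim
            .unwrap_or(self.hidden_size / self.num_attention_heads)
    }
}

fn main() {
    // Mistral-7B style: no explicit head_dim, so 4096 / 32 = 128, same as before.
    let mistral_7b = Cfg { hidden_size: 4096, num_attention_heads: 32, head_dim: None };
    assert_eq!(mistral_7b.head_dim(), 128);

    // Mistral-Nemo style (assumed values): the explicit head_dim of 128 differs from
    // hidden_size / num_attention_heads = 5120 / 32 = 160, hence the Option.
    let nemo = Cfg { hidden_size: 5120, num_attention_heads: 32, head_dim: Some(128) };
    assert_eq!(nemo.head_dim(), 128);
}

With the two new Which variants, the example binary can presumably be pointed at the Nemo checkpoints via the clap value names declared above, e.g. --which nemo-2407 or --which nemo-instruct-2407.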