From e83363c0ad6b557049ff141476a11b61e8dd7fc5 Mon Sep 17 00:00:00 2001
From: DimitriTimoz
Date: Wed, 6 Nov 2024 18:37:01 +0100
Subject: [PATCH 1/2] Implement BertPooler

---
 candle-examples/examples/bert/main.rs   |  2 +-
 candle-examples/examples/splade/main.rs |  2 +-
 candle-transformers/src/models/bert.rs  | 57 +++++++++++++++++++++++--
 candle-wasm-examples/bert/src/bin/m.rs  |  2 +-
 4 files changed, 57 insertions(+), 6 deletions(-)

diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index cb80f6eb6d..00720e45ae 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -88,7 +88,7 @@ impl Args {
         if self.approximate_gelu {
             config.hidden_act = HiddenAct::GeluApproximate;
         }
-        let model = BertModel::load(vb, &config)?;
+        let model = BertModel::load(vb, &config, false)?;
         Ok((model, tokenizer))
     }
 }
diff --git a/candle-examples/examples/splade/main.rs b/candle-examples/examples/splade/main.rs
index aa4c60ac41..584b33e722 100644
--- a/candle-examples/examples/splade/main.rs
+++ b/candle-examples/examples/splade/main.rs
@@ -92,7 +92,7 @@ fn main() -> Result<()> {
         println!("Loading weights from pytorch_model.bin");
         VarBuilder::from_pth(&weights_filename, dtype, &device).unwrap()
     };
-    let model = BertForMaskedLM::load(vb, &config)?;
+    let model = BertForMaskedLM::load(vb, &config, false)?;
 
     if let Some(prompt) = args.prompt {
         let tokenizer = tokenizer
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index bdc0385deb..8450de6738 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -11,6 +11,7 @@ pub enum HiddenAct {
     Gelu,
     GeluApproximate,
     Relu,
+    Tanh,
 }
 
 struct HiddenActLayer {
@@ -31,6 +32,7 @@ impl HiddenActLayer {
             HiddenAct::Gelu => xs.gelu_erf(),
             HiddenAct::GeluApproximate => xs.gelu(),
             HiddenAct::Relu => xs.relu(),
+            HiddenAct::Tanh => xs.tanh(),
         }
     }
 }
@@ -436,16 +438,45 @@ impl BertEncoder {
     }
 }
 
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L654
+struct BertPooler {
+    dense: Linear,
+    activation: HiddenActLayer,
+    span: tracing::Span,
+}
+
+impl BertPooler {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
+        let activation = HiddenActLayer::new(HiddenAct::Tanh);
+        Ok(Self {
+            dense,
+            activation,
+            span: tracing::span!(tracing::Level::TRACE, "pooler"),
+        })
+    }
+
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let first_token_tensor = hidden_states.narrow(1, 0, 1)?;
+        let pooled_output = self
+            .activation
+            .forward(&self.dense.forward(&first_token_tensor)?)?;
+        Ok(pooled_output)
+    }
+}
+
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
 pub struct BertModel {
     embeddings: BertEmbeddings,
     encoder: BertEncoder,
+    pooler: Option<BertPooler>,
     pub device: Device,
     span: tracing::Span,
 }
 
 impl BertModel {
-    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+    pub fn load(vb: VarBuilder, config: &Config, add_pooling_layer: bool) -> Result<Self> {
         let (embeddings, encoder) = match (
             BertEmbeddings::load(vb.pp("embeddings"), config),
             BertEncoder::load(vb.pp("encoder"), config),
         ) {
@@ -466,9 +497,29 @@ impl BertModel {
                 }
             }
         };
+
+        let pooler = if add_pooling_layer {
+            match BertPooler::load(vb.pp("pooler"), config) {
+                Ok(pooler) => Some(pooler),
+                Err(err) => {
+                    if let Some(model_type) = &config.model_type {
+                        Some(BertPooler::load(
+                            vb.pp(format!("{model_type}.pooler")),
+                            config,
+                        )?)
+                    } else {
+                        return Err(err);
+                    }
+                }
+            }
+        } else {
+            None
+        };
+
         Ok(Self {
             embeddings,
             encoder,
+            pooler,
             device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
         })
     }
@@ -583,8 +634,8 @@ pub struct BertForMaskedLM {
 }
 
 impl BertForMaskedLM {
-    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
-        let bert = BertModel::load(vb.pp("bert"), config)?;
+    pub fn load(vb: VarBuilder, config: &Config, add_pooling_layer: bool) -> Result<Self> {
+        let bert = BertModel::load(vb.pp("bert"), config, add_pooling_layer)?;
         let cls = BertOnlyMLMHead::load(vb.pp("cls"), config)?;
         Ok(Self { bert, cls })
     }
diff --git a/candle-wasm-examples/bert/src/bin/m.rs b/candle-wasm-examples/bert/src/bin/m.rs
index 9e5cf913ad..5ef6e8886d 100644
--- a/candle-wasm-examples/bert/src/bin/m.rs
+++ b/candle-wasm-examples/bert/src/bin/m.rs
@@ -22,7 +22,7 @@ impl Model {
         let config: Config = serde_json::from_slice(&config)?;
         let tokenizer =
             Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
-        let bert = BertModel::load(vb, &config)?;
+        let bert = BertModel::load(vb, &config, false)?;
 
         Ok(Self { bert, tokenizer })
     }

From c2b65edfba0caed690ec58cdb0cb2b599ea8a854 Mon Sep 17 00:00:00 2001
From: DimitriTimoz
Date: Wed, 6 Nov 2024 19:31:41 +0100
Subject: [PATCH 2/2] Bert: add conditional pooling

---
 candle-transformers/src/models/bert.rs | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 8450de6738..70233e1f01 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -540,7 +540,11 @@ impl BertModel {
         // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
         let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
         let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
-        Ok(sequence_output)
+        if let Some(pooler) = &self.pooler {
+            pooler.forward(&sequence_output)
+        } else {
+            Ok(sequence_output)
+        }
     }
 }
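Usage note (not part of the patch series): a minimal sketch of how the new argument is meant to be used, assuming the crate paths used by the in-repo examples and the optional-attention-mask forward signature this branch of bert.rs already has. pooled_cls and token_ids are hypothetical names for illustration.

use candle::{Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config};

// Hypothetical helper: with add_pooling_layer = true, forward returns the
// tanh-activated dense projection of the first ([CLS]) token with shape
// (batch, 1, hidden_size); with false it returns the full sequence output
// (batch, seq_len, hidden_size) exactly as before.
fn pooled_cls(vb: VarBuilder, config: &Config, token_ids: &Tensor) -> Result<Tensor> {
    let model = BertModel::load(vb, config, /* add_pooling_layer = */ true)?;
    // Single-segment input: all token type ids are zero.
    let token_type_ids = token_ids.zeros_like()?;
    // None lets the model build an all-ones attention mask internally.
    model.forward(token_ids, &token_type_ids, None)
}

Every existing call site above passes false, so the bert, splade, and wasm examples keep their previous output and only thread the extra argument through.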