bert attention mask (#1934)
* bert attention mask

* Allow for using None as a mask.

* Revert part of the changes so that the proper default mask applies.

* Cosmetic change.

* Another cosmetic tweak.

---------

Co-authored-by: Laurent <[email protected]>
lz1998 and LaurentMazare authored Aug 1, 2024
1 parent 24d54d0 commit 4a52aeb
Showing 3 changed files with 53 additions and 20 deletions.
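
For orientation before the diff: a minimal sketch of how a caller might thread the new optional attention mask through `BertModel::forward` after this change. The wrapper function and its name are illustrative only, and the crate paths follow this repository's convention, where `candle` is the workspace rename of `candle-core`.

```rust
use candle::{Result, Tensor};
use candle_transformers::models::bert::BertModel;

/// Hypothetical wrapper: `model`, `token_ids`, and `token_type_ids` are assumed to be
/// built as in candle-examples/examples/bert/main.rs. `attention_mask` holds 1 for real
/// tokens and 0 for padding, i.e. what `Encoding::get_attention_mask` returns, stacked
/// into a (batch, seq_len) tensor.
fn embed(
    model: &BertModel,
    token_ids: &Tensor,
    token_type_ids: &Tensor,
    attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
    // Passing `None` falls back to an all-ones mask inside the model, so unpadded
    // single-sentence inference behaves exactly as before this change.
    model.forward(token_ids, token_type_ids, attention_mask)
}
```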
candle-examples/examples/bert/main.rs (12 changes: 10 additions & 2 deletions)
@@ -126,7 +126,7 @@ fn main() -> Result<()> {
println!("Loaded and encoded {:?}", start.elapsed());
for idx in 0..args.n {
let start = std::time::Instant::now();
- let ys = model.forward(&token_ids, &token_type_ids)?;
+ let ys = model.forward(&token_ids, &token_type_ids, None)?;
if idx == 0 {
println!("{ys}");
}
@@ -163,11 +163,19 @@ fn main() -> Result<()> {
Ok(Tensor::new(tokens.as_slice(), device)?)
})
.collect::<Result<Vec<_>>>()?;
+ let attention_mask = tokens
+ .iter()
+ .map(|tokens| {
+ let tokens = tokens.get_attention_mask().to_vec();
+ Ok(Tensor::new(tokens.as_slice(), device)?)
+ })
+ .collect::<Result<Vec<_>>>()?;

let token_ids = Tensor::stack(&token_ids, 0)?;
+ let attention_mask = Tensor::stack(&attention_mask, 0)?;
let token_type_ids = token_ids.zeros_like()?;
println!("running inference on batch {:?}", token_ids.shape());
- let embeddings = model.forward(&token_ids, &token_type_ids)?;
+ let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
println!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
candle-transformers/src/models/bert.rs (49 changes: 32 additions & 17 deletions)
@@ -230,10 +230,8 @@ impl BertSelfAttention {
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
xs.contiguous()
}
- }

- impl Module for BertSelfAttention {
- fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let query_layer = self.query.forward(hidden_states)?;
let key_layer = self.key.forward(hidden_states)?;
@@ -245,6 +243,7 @@ impl Module for BertSelfAttention {

let attention_scores = query_layer.matmul(&key_layer.t()?)?;
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
+ let attention_scores = attention_scores.broadcast_add(attention_mask)?;
let attention_probs = {
let _enter_sm = self.span_softmax.enter();
candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
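
The broadcast-added mask takes effect through the softmax that follows: positions carrying a large negative bias contribute an exponent of essentially zero, so attention never flows to padding. A standalone sketch of that behaviour with toy values, reusing the same `broadcast_add` and `softmax` calls as the diff (not part of the change itself):

```rust
use candle::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // One attention row over four key positions; the last position is padding.
    let scores = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], &dev)?;
    // Additive mask: 0.0 keeps a position, f32::MIN effectively removes it.
    let mask = Tensor::new(&[[0.0f32, 0.0, 0.0, f32::MIN]], &dev)?;
    let masked = scores.broadcast_add(&mask)?;
    let probs = candle_nn::ops::softmax(&masked, D::Minus1)?;
    // The padded position gets ~0 probability; the remaining three renormalize.
    println!("{probs}");
    Ok(())
}
```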
@@ -307,12 +306,10 @@ impl BertAttention {
span: tracing::span!(tracing::Level::TRACE, "attn"),
})
}
- }

- impl Module for BertAttention {
- fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
- let self_outputs = self.self_attention.forward(hidden_states)?;
+ let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
Ok(attention_output)
}
@@ -398,12 +395,10 @@ impl BertLayer {
span: tracing::span!(tracing::Level::TRACE, "layer"),
})
}
- }

- impl Module for BertLayer {
- fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
- let attention_output = self.attention.forward(hidden_states)?;
+ let attention_output = self.attention.forward(hidden_states, attention_mask)?;
// TODO: Support cross-attention?
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
// TODO: Support something similar to `apply_chunking_to_forward`?
@@ -429,15 +424,13 @@ impl BertEncoder {
let span = tracing::span!(tracing::Level::TRACE, "encoder");
Ok(BertEncoder { layers, span })
}
- }

- impl Module for BertEncoder {
- fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut hidden_states = hidden_states.clone();
// Use a loop rather than a fold as it's easier to modify when adding debug/...
for layer in self.layers.iter() {
- hidden_states = layer.forward(&hidden_states)?
+ hidden_states = layer.forward(&hidden_states, attention_mask)?
}
Ok(hidden_states)
}
@@ -481,10 +474,32 @@ impl BertModel {
})
}

- pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+ pub fn forward(
+ &self,
+ input_ids: &Tensor,
+ token_type_ids: &Tensor,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
let _enter = self.span.enter();
let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
- let sequence_output = self.encoder.forward(&embedding_output)?;
+ let attention_mask = match attention_mask {
+ Some(attention_mask) => attention_mask.clone(),
+ None => input_ids.ones_like()?,
+ };
+ // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
+ let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
+ let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
Ok(sequence_output)
}
}

+ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
+ let attention_mask = match attention_mask.rank() {
+ 3 => attention_mask.unsqueeze(1)?,
+ 2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
+ _ => candle::bail!("Wrong shape for input_ids or attention_mask"),
+ };
+ let attention_mask = attention_mask.to_dtype(dtype)?;
+ // torch.finfo(dtype).min
+ (attention_mask.ones_like()? - attention_mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)
+ }
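
To make the helper's output concrete: a hedged, standalone sketch that repeats the same computation inline (`get_extended_attention_mask` is private to bert.rs) and shows the shape and values produced for a rank-2 tokenizer mask:

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Tokenizer-style mask for one sentence of length 3 whose last token is padding.
    let mask = Tensor::new(&[[1u32, 1, 0]], &dev)?;
    // Same steps as get_extended_attention_mask for a rank-2 mask: insert the head and
    // query dimensions, cast to f32, then map 1 -> 0.0 and 0 -> f32::MIN.
    let extended = mask.unsqueeze(1)?.unsqueeze(1)?.to_dtype(DType::F32)?;
    let extended =
        (extended.ones_like()? - &extended)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)?;
    // Shape (1, 1, 1, 3) with values [0, 0, f32::MIN], ready to broadcast against
    // attention scores of shape (batch, heads, query_len, key_len).
    println!("{:?} {extended}", extended.dims());
    Ok(())
}
```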
candle-wasm-examples/bert/src/bin/m.rs (12 changes: 11 additions & 1 deletion)
@@ -55,11 +55,21 @@ impl Model {
Tensor::new(tokens.as_slice(), device)
})
.collect::<Result<Vec<_>, _>>()?;
+ let attention_mask: Vec<Tensor> = tokens
+ .iter()
+ .map(|tokens| {
+ let tokens = tokens.get_attention_mask().to_vec();
+ Tensor::new(tokens.as_slice(), device)
+ })
+ .collect::<Result<Vec<_>, _>>()?;

let token_ids = Tensor::stack(&token_ids, 0)?;
+ let attention_mask = Tensor::stack(&attention_mask, 0)?;
let token_type_ids = token_ids.zeros_like()?;
console_log!("running inference on batch {:?}", token_ids.shape());
- let embeddings = self.bert.forward(&token_ids, &token_type_ids)?;
+ let embeddings = self
+ .bert
+ .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
console_log!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
