
About multi-head attention in attention is all you need, thanks. #19

Open
sonrisa07 opened this issue Oct 14, 2023 · 2 comments

Comments


sonrisa07 commented Oct 14, 2023

Hello, author. I sincerely hope you can answer this when you see it.
I really want to understand why Q, K, and V appear as the inputs to multi-head attention and are then fed into the three linear layers of each head. Do those three linear layers represent the W_q, W_k, and W_v of each head? If so, the embedding matrix would first have to be converted to Q, K, and V, and then converted again to Q_i, K_i, and V_i by the W_q, W_k, and W_v of a given head, so the embedding matrix would go through two transformations.
I have looked at several implementations, including yours, and they all feed the embedding matrix directly into the three linear layers of each head.
How is this achieved? Thanks for your help.
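
To make my question concrete, here is a rough sketch of the two transformations I have in mind. The names are hypothetical and only the query path is shown; K and V would be analogous:

```python
import torch
import torch.nn as nn

d_model, n_head = 512, 8
d_head = d_model // n_head                       # 64

x = torch.randn(2, 10, d_model)                  # embedding matrix X: (batch, seq_len, d_model)

# step 1: produce the inputs labelled Q, K, V in the figure
w_q = nn.Linear(d_model, d_model)
q = w_q(x)                                       # Q = X @ W_q

# step 2: the per-head linear layers W_q^i inside multi-head attention
w_q_i = nn.ModuleList([nn.Linear(d_model, d_head) for _ in range(n_head)])
q_i = [w(q) for w in w_q_i]                      # Q_i = Q @ W_q^i  -> X is transformed twice
```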


adhiraj2001 commented Oct 15, 2023

As far as I understand, your doubt is: why don't Q, K, and V go through n_head separate linear transformations to extract the Q_i, K_i, and V_i corresponding to each head?
=>
The answer is that, to avoid a significant growth in computational and parameter cost, we set d_q = d_k = d_v = d_model / n_head. [1]
This is what the split() function in MultiHeadAttention does, and the concat() function is then essentially a weighted combination of these heads, just like in the paper.
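
As a minimal sketch of what I mean, assuming the usual (batch, seq_len, d_model) layout (this is my own simplification for illustration, not the exact code in this repo):

```python
import torch
import torch.nn as nn

class MultiHeadAttentionSketch(nn.Module):
    """One full-width projection per Q/K/V, then a reshape ("split") into heads."""
    def __init__(self, d_model=512, n_head=8):
        super().__init__()
        assert d_model % n_head == 0
        self.n_head = n_head
        self.d_head = d_model // n_head           # d_k = d_v = d_model / n_head
        self.w_q = nn.Linear(d_model, d_model)    # implicitly holds every head's W_q^i
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)    # output projection after concat

    def split(self, t):
        # (batch, seq, d_model) -> (batch, n_head, seq, d_head)
        b, s, _ = t.shape
        return t.view(b, s, self.n_head, self.d_head).transpose(1, 2)

    def forward(self, q, k, v):
        q, k, v = self.split(self.w_q(q)), self.split(self.w_k(k)), self.split(self.w_v(v))
        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5
        out = scores.softmax(dim=-1) @ v
        # concat: (batch, n_head, seq, d_head) -> (batch, seq, d_model)
        b, _, s, _ = out.shape
        out = out.transpose(1, 2).reshape(b, s, -1)
        return self.w_o(out)
```

Slicing the single d_model-by-d_model projection into n_head chunks of width d_head is mathematically the same as giving each head its own d_model-by-d_head matrix W_q^i; the per-head matrices are simply stored side by side in one weight, which is why the embedding only needs to pass through one linear layer per Q/K/V.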

You can see a similar implementation in PyTorch source code as well. [2]

Anyone else reading this, please correct me if I am wrong or if there are other benefits or reasons for using this implementation.

EDIT:
It's clearly mentioned in the paper as well:

In this work we employ h = 8 parallel attention layers, or heads. For each of these we use d_k = d_v = d_model / h = 64. Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.

For the most faithful implementations of research papers, you should also check out the labml.ai annotated PyTorch implementations repository. [3]

[1] https://d2l.ai/chapter_attention-mechanisms-and-transformers/multihead-attention.html
[2] https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#MultiheadAttention
[3] http://nlp.seas.harvard.edu/annotated-transformer/

@sonrisa07 (Author) commented


Thank you for your comment, but it doesn't address my question. For instance, consider a sequence for which we produce an embedding matrix, named X. X is then sent to every head and multiplied by that head's W_q, W_k, and W_v, respectively, so each head generates its corresponding Q, K, and V.

However, the paper's multi-head attention illustration labels the inputs to each head's linear layers as Q, K, and V rather than X, X, and X.
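
Concretely, what I see in the implementations (and in torch.nn.MultiheadAttention itself) is that for self-attention the same embedding matrix X is simply passed as all three inputs, for example (shapes here are just placeholders):

```python
import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 512)              # embedding matrix X: (batch, seq_len, d_model)
out, _ = attn(query=x, key=x, value=x)   # the Q, K, V inputs are all the same X
```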
