Issue Type
Documentation Bug
Source
source
Keras Version
2.14
Custom Code
Yes
OS Platform and Distribution
Ubuntu 22.04
Python version
3.10
GPU model and memory
Nvidia RTX4070 (12GB)
Current Behavior?
Hi,
I've spotted a mistake in the Vision Transformer examples on Keras.io [3,4,5,6,7].
In all five examples, the ViT architecture is built with a single hyper-parameter named projection_dim, which is used both as the model's hidden dimension and as the dimension of the queries, keys, and values in the multi-head attention layer. These two hyper-parameters should not be the same; according to [1], they are related as follows:
hidden dimension = number of heads * qkv dimension
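In Keras terms the distinction matters because the key_dim argument of layers.MultiHeadAttention is the per-head query/key/value size, while the hidden dimension is the width of the token sequence fed into the layer. A minimal sketch with ViT-Base's numbers (12 heads of size 64, hidden size 768 [2]); the variable names are illustrative, not taken from the examples:

```python
from keras import layers

# ViT-Base numbers: 12 heads of size 64 give a hidden width of 768 [2].
num_heads = 12
qkv_dim = 64                      # per-head q/k/v size -> MultiHeadAttention's key_dim
hidden_dim = num_heads * qkv_dim  # 768, the width the patch embeddings should have

# key_dim is per head; the hidden dimension is carried by the last axis of the
# input tokens, so the two are related but distinct hyper-parameters.
attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=qkv_dim)
```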
One simple way to verify this is to count the model's trainable parameters.
Using the architecture from the Keras.io examples with the same hyper-parameters as Vision Transformer Base, the model ends up with only 15 million parameters, whereas ViT-Base has 86 million [2].
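As a quick, self-contained check (a minimal sketch, not the example code itself), the attention block alone already shows the gap: with 12 heads of size 64, feeding the attention layer tokens of width projection_dim instead of num_heads * projection_dim shrinks its weight count from roughly 2.36M (the size of ViT-Base's attention block) to roughly 0.2M.

```python
import keras
from keras import layers

num_heads = 12
projection_dim = 64                      # per-head q/k/v size, as in the examples
hidden_dim = num_heads * projection_dim  # 768

def attention_param_count(token_dim):
    """Build a MultiHeadAttention layer on tokens of width token_dim and count its weights."""
    tokens = keras.Input(shape=(196, token_dim))  # 196 tokens, e.g. 14 x 14 patches
    attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim)
    attention(tokens, tokens)  # calling the layer builds its weights
    return attention.count_params()

print(attention_param_count(projection_dim))  # ~0.2M: what the examples currently build
print(attention_param_count(hidden_dim))      # ~2.36M: ViT-Base's attention block
```

The sequence length does not affect these counts; only the token width and key_dim do, so the same comparison holds for the smaller images used in the examples.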
To fix this issue (a consolidated sketch follows these steps):
A hidden_dim parameter can be defined as:
hidden_dim = projection_dim * num_heads
The encoded patches should be projected into the hidden dimension instead of projection_dim:
encoded_patches = PatchEncoder(num_patches, hidden_dim)(patches)
The transformer_units should also use the hidden dimension:
transformer_units = [hidden_dim * 2, hidden_dim]
Then, if the same hyper-parameters as in the original paper are used, the number of trainable parameters will match ViT-Base's.
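Putting the pieces together, here is a minimal, self-contained sketch of the proposed wiring; the encoder block follows the structure of the image-classification example [3], with the PatchEncoder output and mlp helper replaced by a plain input and Dense layers so it runs standalone:

```python
import keras
from keras import layers

num_heads = 12
projection_dim = 64                               # per-head q/k/v size
hidden_dim = projection_dim * num_heads           # 768, the model's hidden dimension
transformer_units = [hidden_dim * 2, hidden_dim]  # MLP widths scaled from hidden_dim
transformer_layers = 12
num_patches = 196                                 # e.g. 14 x 14 patches

# Stand-in for PatchEncoder(num_patches, hidden_dim): tokens already projected to hidden_dim.
encoded_patches = keras.Input(shape=(num_patches, hidden_dim))
x = encoded_patches

for _ in range(transformer_layers):
    x1 = layers.LayerNormalization(epsilon=1e-6)(x)
    attention_output = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=projection_dim, dropout=0.1  # key_dim stays per-head
    )(x1, x1)
    x2 = layers.Add()([attention_output, x])
    x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
    for units in transformer_units:  # stand-in for the examples' mlp() helper
        x3 = layers.Dense(units, activation="gelu")(x3)
        x3 = layers.Dropout(0.1)(x3)
    x = layers.Add()([x3, x2])

encoder = keras.Model(encoded_patches, x)
print(encoder.count_params())  # parameter count of the 12 encoder blocks with these settings
```

(One side note: ViT-Base's MLP width in [2] is 3072, i.e. 4 * hidden_dim, so matching its 86 million parameters exactly also requires scaling transformer_units accordingly.)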
I understand that the authors may have used alternative versions of the original model, but this particular modification can significantly change the model's behaviour.
If you need any further information, please let me know.
Best wishes,
Angelos
[1] See Table 3 in the original Transformer paper: https://arxiv.org/pdf/1706.03762
[2] https://arxiv.org/pdf/2010.11929
[3] https://keras.io/examples/vision/image_classification_with_vision_transformer/
[4] https://keras.io/examples/vision/vit_small_ds/
[5] https://keras.io/examples/vision/object_detection_using_vision_transformer/
[6] https://keras.io/examples/vision/token_learner/
[7] https://keras.io/examples/vision/vit_small_ds/
Standalone code to reproduce the issue or tutorial link
Relevant log output
No response