The main branch of this repository aims to reproduce Better plain ViT baselines for ImageNet-1k in PyTorch, in particular the 76.7% top-1 validation set accuracy of the Head: MLP → linear variant after 90 epochs. This variant is not inferior to the default, and I have personally had better experience with the simpler prediction head. The changes I made to the big_vision reference implementation in my attempts to make the results converge reside in the grad_accum_wandb branch. In the rest of this README I highlight some of the discrepancies I resolved.
In Better plain ViT baselines for ImageNet-1k only the first 99% of the training data is used for training while the
remaining 1% is used for minival "to encourage the community to stop selecting design choices on the validation (de-facto test) set". This, however, is difficult to reproduce with torchvision.datasets
since datasets.ImageNet()
is ordered by class label, unlike tfds
where the ordering is somewhat randomized:
```python
import tensorflow_datasets as tfds
from collections import Counter

ds = tfds.builder('imagenet2012').as_dataset(split='train[99%:]')
c = Counter(int(e['label']) for e in ds)  # class label -> number of examples in the last 1%
>>> len(c)
999
>>> max(c.values())
27
>>> min(c.values())
3
```
Naively taking the same slices with torchvision.datasets prevented the model from ever seeing the last few classes and resulted in near-random performance on the minival: the only minival class the model had learned was the one that happened to straddle the boundary between the first 99% and the last 1%.
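A quick way to see the problem (a hedged sketch; it assumes ImageNet has already been downloaded to the illustrative path ./imagenet):

```python
from collections import Counter
from torchvision import datasets

# datasets.ImageNet orders samples by class label, so the last 1% of the
# train split covers only the last few classes.
ds = datasets.ImageNet('./imagenet', split='train')
tail = ds.targets[int(len(ds) * 0.99):]   # labels of the would-be minival slice
print(len(Counter(tail)))                 # a handful of classes instead of ~1000
```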
Instead of randomly selecting 99% of the training data or copying the tfds 99% slice, I just fell back to training on 100% of the training data. ImageNet-1k has 1281167 training images, so a batch size of 1024 results in 1281167 // 1024 = 1251 steps per epoch if we drop the last odd lot. big_vision, however, doesn't train the model epoch by epoch: instead, it makes the dataset iterator infinite and trains for the equivalent number of steps. Furthermore, it rounds the total number of steps instead of dropping the last partial batch. The 90-epoch equivalent is therefore round(1281167 / 1024 * 90) = 112603 steps, and mup-vit main follows this practice.
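In code, the two step-counting conventions:

```python
num_train_images = 1281167
batch_size = 1024
epochs = 90

steps_per_epoch = num_train_images // batch_size             # 1251, last odd lot dropped
total_steps = round(num_train_images / batch_size * epochs)  # 112603, big_vision style
```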
big_vision warms up from a learning rate of 0, but torch.optim.lr_scheduler.LinearLR() disallows starting from a factor of exactly 0, so I implemented the warmup from 0 with torch.optim.lr_scheduler.LambdaLR() instead.
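A minimal sketch of that warmup with LambdaLR (the 10,000-step warmup length and the cosine decay after warmup are illustrative assumptions, not necessarily this repo's exact schedule):

```python
import math
import torch

model = torch.nn.Linear(10, 10)  # stand-in for the ViT
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

warmup_steps = 10_000   # illustrative
total_steps = 112_603

def lr_lambda(step: int) -> float:
    # Linear warmup from 0 (LinearLR cannot start at exactly 0), then cosine decay.
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```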
In big_vision, config.wd is only scaled by the global LR schedule, but torch.optim.AdamW() first multiplies "weight_decay" by the LR. The correct equivalent value for weight_decay is therefore 0.1, so that it matches config.wd = 0.0001 once multiplied by config.lr = 0.001.
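Concretely, a minimal sketch of the intended equivalence:

```python
import torch

model = torch.nn.Linear(10, 10)  # stand-in module
# AdamW applies p *= 1 - lr * weight_decay each step, so with lr = 0.001 the
# effective decay is 0.001 * 0.1 = 0.0001, matching big_vision's config.wd.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.1)
```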
The simplified ViT described in Better plain ViT baselines for ImageNet-1k is not readily available in PyTorch. E.g. vit_pytorch's simple_vit and simple_flash_attn_vit are rather dated and don't take advantage of torch.nn.MultiheadAttention(), so I rolled my own. I had to fix some of the parameter initializations, however:
1. torch.nn.MultiheadAttention() comes with its own issues. When Q, K, and V are of the same dimension, their projection matrices are combined into self.in_proj_weight, whose initial values are set with xavier_uniform_(). Likely unintentionally, this means that the values are sampled from the uniform distribution U(−a, a) where a = sqrt(3 / (2 * hidden_dim)) instead of sqrt(3 / hidden_dim). Furthermore, the output projection is initialized as NonDynamicallyQuantizableLinear(), whose initial values are sampled from U(−sqrt(k), sqrt(k)), k = 1 / hidden_dim. Both are therefore re-initialized with U(−a, a) where a = sqrt(3 / hidden_dim)¹ to conform with the jax.nn.initializers.xavier_uniform() used by the reference ViT from big_vision.
2. pytorch's own nn.init.trunc_normal_() doesn't take the effect of truncation on the stddev into account, so I used the magic factor from the JAX repo to re-initialize the patchifying nn.Conv2d (a sketch of both fixes follows this list).
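A simplified sketch of these two fixes (illustrative module names; the fan-in used for the conv init assumes Lecun-normal-style scaling, and the constant is the truncated-normal correction factor from the JAX initializers):

```python
import math
import torch
from torch import nn

hidden_dim, patch_size = 384, 16  # ViT-S/16-like numbers, for illustration
mha = nn.MultiheadAttention(hidden_dim, num_heads=6, batch_first=True)
patchify = nn.Conv2d(3, hidden_dim, kernel_size=patch_size, stride=patch_size)

# 1. Re-initialize the fused QKV projection and the output projection with
#    U(-a, a), a = sqrt(3 / hidden_dim), i.e. Xavier uniform for a square
#    (hidden_dim, hidden_dim) matrix, as jax.nn.initializers.xavier_uniform() gives.
a = math.sqrt(3.0 / hidden_dim)
nn.init.uniform_(mha.in_proj_weight, -a, a)
nn.init.uniform_(mha.out_proj.weight, -a, a)

# 2. Truncated-normal init for the patchifying conv, dividing the target stddev
#    by the correction factor so the post-truncation stddev comes out right.
fan_in = 3 * patch_size * patch_size
std = math.sqrt(1.0 / fan_in) / 0.87962566103423978
nn.init.trunc_normal_(patchify.weight, std=std, a=-2.0 * std, b=2.0 * std)
```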
After 1 and 2, all of the summary statistics of the model parameters match those of the reference implementation at initialization.
Torchvision's v2.RandAugment() defaults to zero padding, whereas big_vision's randaug() uses the RGB value (128, 128, 128) as the replacement value. In both cases I specified the latter to conform to the reference implementation.
The model trained with all of the above for 90 epochs reached 76.91% top-1 validation set accuracy, but the loss curve and the gradient L2 norm clearly show that it deviates from the reference:
It turned out that RandAugment(num_ops=2, magnitude=10) means very different things in torchvision vs. big_vision. I created the following 224 × 224 black & white calibration grid consisting of 56 × 56 black & white squares:
and applied both versions of RandAugment(2, 10) 100000 times to gather the stats. All of the resulting pixels remain colorless (i.e. for RGB values (r, g, b), r == g == b remains true), so we can sort them from black to white into a spectrum. For the following 2000 × 200 spectra, pixels are sorted top-down, left-right, and each pixel represents 224 * 224 * 100000 / (2000 * 200) = 112 * 112 pixels of the aggregated output, i.e. 1/4 of one output image. In case one batch of 12544 pixels happens to contain different values, I took the average. Here is the spectrum of torchvision RandAugment(2, 10):
Here is the spectrum of torchvision RandAugment(2, 10, fill=[128] * 3)
. We can see that it just shifts the zero-padding part of the black into the (128, 128, 128) gray:
And here is the spectrum of big_vision randaug(2, 10)
:
Digging into the codebase, we can see that while torchvision's v2.RandAugment() sticks with the original 14-transform lineup of RandAugment: Practical automated data augmentation with a reduced search space, big_vision's own randaug() omits the Identity no-op and adds 3 new transforms, Invert, SolarizeAdd, and Cutout, along with other subtler discrepancies (e.g. Sharpness is considered "signed" in torchvision, so half of the time the transform blurs the image instead, while in big_vision it always sharpens the image). What I did then was to subclass torchvision's v2.RandAugment(), remove & add transforms accordingly, and use a variety of calibration grids to make sure that the transforms stay within ±1 of the RGB values given by their big_vision counterparts. The sole exception is Contrast: more on that later. Even with that exception, this near-replication of big_vision's randaug(2, 10) results in a near-identical spectrum:
Training with the near-replication of big_vision randaug(2, 10) for 90 epochs reached 77.27% top-1 validation set accuracy and the gradient L2 norm looks the same, but the loss curve still differs:
It turned out that, besides the different default minimum scale (8% vs. 5%), the "Inception crop" implemented as torchvision v2.RandomResizedCrop()
is not the same as calling tf.slice()
with the bbox returned by tf.image.sample_distorted_bounding_box()
:
- They both rejection-sample the crop, but v2.RandomResizedCrop() is hardcoded to try 10 times while tf.image.sample_distorted_bounding_box() defaults to 100 attempts.
- v2.RandomResizedCrop() samples the aspect ratio uniformly in log space; tf.image.sample_distorted_bounding_box() samples it uniformly in linear space.
- v2.RandomResizedCrop() samples the cropped area uniformly, while tf.image.sample_distorted_bounding_box() samples the crop height uniformly given the aspect ratio and the area range (see the sketch after this list).
- If all attempts fail, v2.RandomResizedCrop() at least crops the image to make sure that the aspect ratio falls within range before resizing; tf.image.sample_distorted_bounding_box() just returns the whole image (to be resized).
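For illustration, a rough sketch of the TF-style sampling described above (not the repo's actual code; the rounding and clamping only approximate the TF op):

```python
import math
import random

def tf_style_inception_crop_params(height, width,
                                   area_range=(0.05, 1.0),
                                   ratio_range=(3 / 4, 4 / 3),
                                   max_attempts=100):
    """Sample (top, left, h, w) roughly the way
    tf.image.sample_distorted_bounding_box() does: aspect ratio uniform in
    linear space, crop height uniform given that ratio and the area range."""
    for _ in range(max_attempts):
        aspect_ratio = random.uniform(*ratio_range)  # width / height
        min_area = area_range[0] * height * width
        max_area = area_range[1] * height * width
        # Height bounds implied by the area range at this aspect ratio.
        h_min = int(round(math.sqrt(min_area / aspect_ratio)))
        h_max = int(round(math.sqrt(max_area / aspect_ratio)))
        # Keep the implied width and height within the image.
        h_max = min(h_max, height, int(width / aspect_ratio))
        if h_min > h_max:
            continue
        h = random.randint(h_min, h_max)
        w = int(round(h * aspect_ratio))
        if not (min_area <= h * w <= max_area) or w > width or w < 1:
            continue
        top = random.randint(0, height - h)
        left = random.randint(0, width - w)
        return top, left, h, w
    # All attempts failed: return the whole image (to be resized), like TF does.
    return 0, 0, height, width
```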
We can verify these differences by taking stats of the crop size given the same image. Here is the density plot of (h, w) returned by v2.RandomResizedCrop.get_params(..., scale=(0.05, 1.0), ratio=(3/4, 4/3)), given an image of (height, width) = (256, 512), N = 10000000:
I got about 14340 crop failures, resulting in a bright pixel at the bottom right, but otherwise the density is roughly uniform. In comparison, here is what tf.image.sample_distorted_bounding_box(..., area_range=[0.05, 1]) returns:
While cropping never failed, we can clearly see that it oversamples smaller crop areas, as if there were light shining from the top-left (notebook). The last discrepancy goes away after re-implementing tf.image.sample_distorted_bounding_box()
's sampling logic:
This true reproduction model reached 76.94% top-1 validation set accuracy after 90 epochs. Now let's turn our attention to big_vision itself and double-check the effects of its bugs, inconsistencies, and unusual features.
big_vision grad_accum_wandb branch
I first bolted on wandb logging and revived utils.accumulate_gradient() to run a 1024 batch size on my GeForce RTX 3080 Laptop GPU. The TensorBook is unable to handle shuffle_buffer_size = 250_000, so I shrank it to 150_000. Finally, I fell back to training on 100% of the training data to converge to what I had to do with pytorch. This resulted in the 76.74% top-1 validation set accuracy of big-vision-repo-attempt referenced above, consistent with the reported 76.7% top-1 validation set accuracy.
It turned out that one of big_vision's randaug() transforms, contrast(), is broken. In short, the code that is meant to calculate the average grayscale value of the image
```python
# Compute the grayscale histogram, then compute the mean pixel value,
# and create a constant image size of that value. Use that as the
# blending degenerate target of the original image.
hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256)
mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0
```
is instead calculating image_area / 256, so for our 224 × 224 images, mean is always 196. What it should do is the following:
```python
# Compute the grayscale histogram, then compute the mean pixel value,
# and create a constant image size of that value. Use that as the
# blending degenerate target of the original image.
hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256)
mean = tf.reduce_sum(
    tf.cast(hist, tf.float32) * tf.linspace(0., 255., 256)) / float(image_height * image_width)
```
We can visualize this bug by using the following calibration grid as the input:
and compare the output given by the broken contrast()
:
vs. the output after the fix:
Some CV people are aware of this bug (1, 2) but AFAIK it wasn't documented anywhere in public. As an aside, the solarize() transform has its own integer overflow bug, but it just happens to have no effect when magnitude=_MAX_LEVEL here.
decode_jpeg_and_inception_crop()
used by the training data pipeline defaults to bilinear interpolation without anti-aliasing for resizing, but resize_small()
used by the validation data pipeline defaults to area interpolation that "always anti-aliases". Furthermore, torchvision doesn't support resizing with area interpolation. For consistency, I changed both to bilinear interpolation with anti-aliasing.
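A sketch of the consistent setting on both sides (illustrative calls; the actual pipelines configure this through their own arguments):

```python
import tensorflow as tf
from torchvision.transforms import InterpolationMode, v2

image = tf.zeros([256, 256, 3])  # stand-in for a decoded training image

# tf.data side: bilinear interpolation with anti-aliasing.
resized = tf.image.resize(image, [224, 224], method='bilinear', antialias=True)

# torchvision side: the matching setting.
resize = v2.Resize((224, 224), interpolation=InterpolationMode.BILINEAR, antialias=True)
```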
tf.io.decode_jpeg()
by default lets the system decide the JPEG decompression algorithm. Specifying dct_method="INTEGER_ACCURATE"
makes it behave like the PIL/cv2/PyTorch counterpart (see also the last few cells of RandAugmentCalibration.ipynb). This option is exposed as decode(precise=True)
in big_vision but is left unexposed for decode_jpeg_and_inception_crop()
, so I added the precise
argument to the latter.
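For reference, the decoding call this boils down to:

```python
import tensorflow as tf

jpeg_bytes = tf.io.read_file('example.jpg')  # any JPEG file
# "INTEGER_ACCURATE" makes TF's JPEG decoding behave like the PIL/cv2/PyTorch counterpart.
image = tf.io.decode_jpeg(jpeg_bytes, channels=3, dct_method="INTEGER_ACCURATE")
```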
Changing all of the above seems to have no apparent effect on the model, however (76.87% top-1 validation set accuracy).
optax.scale_by_adam() supports the unusual option of using a different dtype for the 1st-order accumulator, mu_dtype, and the reference implementation uses bfloat16 for it instead of float32 like the rest of the model. Changing it back to float32, however, still has no apparent effect (76.77% top-1 validation set accuracy).
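The two settings side by side (using optax directly, as a sketch; the reference wires this through its own optimizer config):

```python
import jax.numpy as jnp
import optax

# Reference setting: keep the 1st-order accumulator (mu) in bfloat16.
tx_bf16 = optax.scale_by_adam(mu_dtype=jnp.bfloat16)

# Variant tested here: keep mu in float32 like the rest of the model.
tx_f32 = optax.scale_by_adam(mu_dtype=jnp.float32)
```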
Finally, back to shuffle_buffer_size. Unlike torch.utils.data.DataLoader(shuffle=True), which always fully shuffles by indices, tf.data.Dataset.shuffle(buffer_size) needs to load buffer_size's worth of training examples into main memory and fully shuffles iff buffer_size = dataset.cardinality(). To test whether the incomplete shuffle so far had hurt performance, I launched an 8x A100-SXM4-40GB instance on Lambda and trained a big_vision model on it with all of the above and config.input.shuffle_buffer_size = 1281167, the size of the ImageNet-1k training set. It still has no apparent effect (76.85% top-1 validation set accuracy).
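The difference in a nutshell (a toy sketch, not the actual input pipelines):

```python
import tensorflow as tf
import torch

# torch: shuffle=True permutes indices over the entire dataset every epoch.
dataset = torch.utils.data.TensorDataset(torch.arange(10))
loader = torch.utils.data.DataLoader(dataset, shuffle=True)

# tf.data: only buffer_size elements are candidates for the next sample, so the
# shuffle is complete only when buffer_size covers the whole dataset.
ds = tf.data.Dataset.range(10).shuffle(buffer_size=10)
```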
As a by-product, this also proves that big_vision gradient accumulation and multi-GPU training are fully equivalent.
This is the true end of reproducing Better plain ViT baselines for ImageNet-1k in pytorch: there is no if or but, no mystery left. It's rather ironic that after checking and fixing discrepancies for months, fixing the last discrepancy turned out to be a step down (77.27% vs. 76.7%-76.94%) in terms of model performance. I have therefore added --torchvision-inception-crop as an option to switch back to torchvision's Inception crop.
Postscript: metrics of the models aside, in terms of training walltime, modern (2.2+) PyTorch with compile() and JAX are nearly identical on the same GPU. The tiny difference may well be fully explained by the overhead of transposing from channels-last to channels-first and converting from tf.Tensor to torch.Tensor. As for hardware comparison, here are the walltimes reported:
| Hardware | Walltime |
|---|---|
| TPUv3-8 node | 6h30 (99%) |
| 8x A100-SXM4-40GB | 5h41m |
| RTX 3080 Laptop | 5d19h32m |
The 8x A100-SXM4-40GB machine is comparable to, but faster than, a TPUv3-8 node. The RTX 3080 Laptop is unsurprisingly out of their league: 1 day on it is about the same as 1 hour on the other two.
¹ See pytorch/pytorch#57109 (comment) for the origin of this discrepancy.