diff --git a/vit_pytorch/distill.py b/vit_pytorch/distill.py index 79bf8c5..60032a9 100644 --- a/vit_pytorch/distill.py +++ b/vit_pytorch/distill.py @@ -127,7 +127,7 @@ def __init__( ) def forward(self, img, labels, temperature = None, alpha = None, **kwargs): - b, *_ = img.shape + alpha = alpha if exists(alpha) else self.alpha T = temperature if exists(temperature) else self.temperature