From 711f6308fd02e23fae0d27514e08b06bb1fd5059 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Fri, 8 Dec 2023 11:33:40 -0500 Subject: [PATCH 1/5] minor --- .../imagenet_jax/randaugment.py | 2 +- .../imagenet_vit/imagenet_jax/models.py | 103 +++++++++-- .../imagenet_vit/imagenet_jax/workload.py | 28 ++- .../imagenet_vit/imagenet_pytorch/models.py | 121 ++++++++++--- .../imagenet_vit/imagenet_pytorch/workload.py | 28 ++- .../workloads/imagenet_vit/workload.py | 15 ++ tests/modeldiffs/imagenet_vit/compare.py | 67 ++++++- tests/modeldiffs/imagenet_vit/compare_glu.py | 163 ++++++++++++++++++ .../imagenet_vit/compare_post_ln.py | 163 ++++++++++++++++++ 9 files changed, 644 insertions(+), 46 deletions(-) create mode 100644 tests/modeldiffs/imagenet_vit/compare_glu.py create mode 100644 tests/modeldiffs/imagenet_vit/compare_post_ln.py diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 5f92b1482..8fa1c0789 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,7 +8,7 @@ import math import tensorflow as tf -from tensorflow_addons import image as contrib_image +# from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index ab5d1839e..4a97ee661 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -34,6 +34,7 @@ def posemb_sincos_2d(h: int, class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" mlp_dim: Optional[int] = None # Defaults to 4x input dim. + use_glu: bool = False dropout_rate: float = 0.0 @nn.compact @@ -47,6 +48,13 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: d = x.shape[2] x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x) x = nn.gelu(x) + + if self.use_glu: + y = nn.Dense( + self.mlp_dim, + **inits)(x) + x = x * y + x = nn.Dropout(rate=self.dropout_rate)(x, train) x = nn.Dense(d, **inits)(x) return x @@ -56,26 +64,47 @@ class Encoder1DBlock(nn.Module): """Single transformer encoder block (MHSA + MLP).""" mlp_dim: Optional[int] = None # Defaults to 4x input dim. 
num_heads: int = 12 + use_glu: bool = False + use_post_layer_norm: bool = False dropout_rate: float = 0.0 @nn.compact def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: - y = nn.LayerNorm(name='LayerNorm_0')(x) - y = nn.SelfAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - deterministic=train, - name='MultiHeadDotProductAttention_1')( - y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - - y = nn.LayerNorm(name='LayerNorm_2')(x) - y = MlpBlock( - mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y + if not self.use_post_layer_norm: + y = nn.LayerNorm(name='LayerNorm_0')(x) + y = nn.SelfAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + y) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + + y = nn.LayerNorm(name='LayerNorm_2')(x) + y = MlpBlock( + mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate, + name='MlpBlock_3')(y, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + else: + y = nn.SelfAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + x) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_0')(x) + + y = MlpBlock( + mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate, + name='MlpBlock_3')(x, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_2')(x) + return x @@ -85,6 +114,8 @@ class Encoder(nn.Module): mlp_dim: Optional[int] = None # Defaults to 4x input dim. num_heads: int = 12 dropout_rate: float = 0.0 + use_glu: bool = False + use_post_layer_norm: bool = False @nn.compact def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: @@ -94,9 +125,35 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: name=f'encoderblock_{lyr}', mlp_dim=self.mlp_dim, num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, dropout_rate=self.dropout_rate) x = block(x, train) - return nn.LayerNorm(name='encoder_layernorm')(x) + if not self.use_post_layer_norm: + return nn.LayerNorm(name='encoder_layernorm')(x) + else: + return x + + +class MAPHead(nn.Module): + """Multihead Attention Pooling.""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim + num_heads: int = 12 + @nn.compact + def __call__(self, x): + n, _, d = x.shape + probe = self.param('probe', + nn.initializers.xavier_uniform(), + (1, 1, d), x.dtype) + probe = jnp.tile(probe, [n, 1, 1]) + + x = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform())(probe, x) + + y = nn.LayerNorm()(x) + x = x + MlpBlock(mlp_dim=self.mlp_dim)(y) + return x[:, 0] class ViT(nn.Module): @@ -112,6 +169,9 @@ class ViT(nn.Module): dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. 
reinit: Optional[Sequence[str]] = None head_zeroinit: bool = True + use_glu: bool = False, + use_post_layer_norm: bool = False, + use_map: bool = False, def get_posemb(self, seqshape: tuple, @@ -145,11 +205,18 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: depth=self.depth, mlp_dim=self.mlp_dim, num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, dropout_rate=dropout_rate, name='Transformer')( x, train=not train) - x = jnp.mean(x, axis=1) + if self.use_map: + x = MAPHead(num_heads=self.num_heads, + mlp_dim=self.mlp_dim + )(x) + else: + x = jnp.mean(x, axis=1) if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 3f3af0564..22fcde66a 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -32,11 +32,16 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None, + head_zeroinit: bool = True) -> spec.ModelInitState: del aux_dropout_rate self._model = models.ViT( dropout_rate=dropout_rate, num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + head_zeroinit=head_zeroinit, **decode_variant('S/16')) params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) @@ -83,3 +88,24 @@ def _eval_model_on_split(self, rng, data_dir, global_step) + + +class ImagenetVitGluWorkload(ImagenetVitWorkload): + + @property + def use_glu(self) -> bool: + return True + + +class ImagenetViTPostLNWorkload(ImagenetVitWorkload): + + @property + def use_post_layer_norm(self) -> bool: + return True + + +class ImagenetViTMapLNWorkload(ImagenetVitWorkload): + + @property + def use_map(self) -> bool: + return True diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py index 55a8e370d..053b0ec76 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -39,18 +39,26 @@ def __init__( self, width: int, mlp_dim: Optional[int] = None, # Defaults to 4x input dim. 
+ use_glu: bool = False, dropout_rate: float = 0.0) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim or 4 * width + self.use_glu = use_glu self.dropout_rate = dropout_rate - self.net = nn.Sequential( - nn.Linear(self.width, self.mlp_dim), - nn.GELU(), - nn.Dropout(self.dropout_rate), - nn.Linear(self.mlp_dim, self.width)) + self.linear1 = nn.Linear(self.width, self.mlp_dim) + self.act_fnc = nn.GELU(approximate='tanh') + self.dropout = nn.Dropout(self.dropout_rate) + + if self.use_glu: + self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim) + else: + self.glu_linear = None + + self.linear2 = nn.Linear(self.mlp_dim, self.width) + self.reset_parameters() def reset_parameters(self) -> None: @@ -61,7 +69,16 @@ def reset_parameters(self) -> None: module.bias.data.normal_(std=1e-6) def forward(self, x: spec.Tensor) -> spec.Tensor: - return self.net(x) + x = self.linear1(x) + x = self.act_fnc(x) + + if self.use_glu: + y = self.glu_linear(x) + x = x * y + + x = self.dropout(x) + x = self.linear2(x) + return x class SelfAttention(nn.Module): @@ -129,29 +146,44 @@ def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, dropout_rate: float = 0.0) -> None: super().__init__() self.width = width self.mlp_dim = mlp_dim self.num_heads = num_heads + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6) self.self_attention1 = SelfAttention(self.width, self.num_heads) self.dropout = nn.Dropout(dropout_rate) self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) - self.mlp3 = MlpBlock(self.width, self.mlp_dim, dropout_rate) + self.mlp3 = MlpBlock(width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=dropout_rate) def forward(self, x: spec.Tensor) -> spec.Tensor: - y = self.layer_norm0(x) - y = self.self_attention1(y) - y = self.dropout(y) - x = x + y - - y = self.layer_norm2(x) - y = self.mlp3(y) - y = self.dropout(y) - x = x + y + if not self.use_post_layer_norm: + y = self.layer_norm0(x) + y = self.self_attention1(y) + y = self.dropout(y) + x = x + y + + y = self.layer_norm2(x) + y = self.mlp3(y) + y = self.dropout(y) + x = x + y + else: + y = self.self_attention1(x) + y = self.dropout(y) + x = x + y + x = self.layer_norm0(x) + + y = self.mlp3(x) + y = self.dropout(y) + x = x + y + x = self.layer_norm2(x) return x @@ -163,6 +195,8 @@ def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12, + use_glu: bool = False, + use_post_layer_norm: bool = False, dropout_rate: float = 0.0) -> None: super().__init__() @@ -170,18 +204,53 @@ def __init__(self, self.width = width self.mlp_dim = mlp_dim self.num_heads = num_heads + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm self.net = nn.ModuleList([ - Encoder1DBlock(self.width, self.mlp_dim, self.num_heads, dropout_rate) + Encoder1DBlock(self.width, self.mlp_dim, self.num_heads, self.use_glu, self.use_post_layer_norm, dropout_rate) for _ in range(depth) ]) - self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) + + if not self.use_post_layer_norm: + self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6) + else: + self.encoder_norm = None def forward(self, x: spec.Tensor) -> spec.Tensor: # Input Encoder. 
for block in self.net: x = block(x) - return self.encoder_norm(x) + if not self.use_post_layer_norm: + return self.encoder_norm(x) + else: + return x + + +class MAPHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12): + super().__init__() + self.width = width + self.mlp_dim = mlp_dim + self.num_heads = num_heads + + self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) + nn.init.xavier_uniform_(self.probe.data) + + self.mha = nn.MultiheadAttention(embed_dim=self.width, num_heads=self.num_heads) + self.layer_nrom = nn.LayerNorm(self.width, eps=1e-6) + self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) + + def forward(self, x): + n, _, _ = x.shape + probe = torch.tile(self.probe, [n, 1, 1]) + + x = self.mha(probe, x) + y = self.layer_nrom(x) + x = x + self.mlp(y) + return x[:, 0] class ViT(nn.Module): @@ -202,6 +271,9 @@ def __init__( rep_size: Union[int, bool] = True, dropout_rate: Optional[float] = 0.0, head_zeroinit: bool = True, + use_glu: bool = False, + use_post_layer_norm: bool = False, + use_map: bool = False, dtype: Any = torch.float32) -> None: super().__init__() if dropout_rate is None: @@ -215,6 +287,9 @@ def __init__( self.num_heads = num_heads self.rep_size = rep_size self.head_zeroinit = head_zeroinit + self.use_glu = use_glu + self.use_post_layer_norm = use_post_layer_norm + self.use_map = use_map self.dtype = dtype if self.rep_size: @@ -234,6 +309,8 @@ def __init__( width=self.width, mlp_dim=self.mlp_dim, num_heads=self.num_heads, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, dropout_rate=dropout_rate) if self.num_classes: @@ -270,7 +347,11 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: x = self.dropout(x) x = self.encoder(x) - x = torch.mean(x, dim=1) + + if self.use_map: + pass + else: + x = torch.mean(x, dim=1) if self.rep_size: x = torch.tanh(self.pre_logits(x)) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index 08a62ede6..9e8af3a68 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -28,12 +28,17 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None, + head_zeroinit: bool = True) -> spec.ModelInitState: del aux_dropout_rate torch.random.manual_seed(rng[0]) model = models.ViT( dropout_rate=dropout_rate, num_classes=self._num_classes, + use_glu=self.use_glu, + use_post_layer_norm=self.use_post_layer_norm, + use_map=self.use_map, + head_zeroinit=head_zeroinit, **decode_variant('S/16')) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) @@ -77,3 +82,24 @@ def model_fn( logits_batch = model(augmented_and_preprocessed_input_batch['inputs']) return logits_batch, None + + +class ImagenetVitGluWorkload(ImagenetVitWorkload): + + @property + def use_glu(self) -> bool: + return True + + +class ImagenetViTPostLNWorkload(ImagenetVitWorkload): + + @property + def use_post_layer_norm(self) -> bool: + return True + + +class ImagenetViTMapLNWorkload(ImagenetVitWorkload): + + @property + def use_map(self) -> bool: + return True diff --git a/algorithmic_efficiency/workloads/imagenet_vit/workload.py 
b/algorithmic_efficiency/workloads/imagenet_vit/workload.py index 61d3acfd3..ed0118ca0 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/workload.py @@ -60,6 +60,21 @@ def validation_target_value(self) -> float: def test_target_value(self) -> float: return 1 - 0.3481 # 0.6519 + @property + def use_post_layer_norm(self) -> bool: + """Whether to use layer normalization after the residual branch.""" + return False + + @property + def use_map(self) -> bool: + """Whether to use multihead attention pooling.""" + return False + + @property + def use_glu(self) -> bool: + """Whether to use GLU in the MLPBlock.""" + return False + @property def eval_batch_size(self) -> int: return 2048 diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index 1022b5b54..3e8b9dcb1 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -3,20 +3,75 @@ # Disable GPU access for both jax and pytorch. os.environ['CUDA_VISIBLE_DEVICES'] = '' -import jax -import torch - from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitWorkload as PytWorkload -from tests.modeldiffs.diff import out_diff +from flax import jax_utils +import jax +import numpy as np +import torch + +from tests.modeldiffs.torch2jax_utils import Torch2Jax +from tests.modeldiffs.torch2jax_utils import value_transform + + +#pylint: disable=dangerous-default-value +def torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform=None, + sd_transform=None, + init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): + jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), + **init_kwargs) + pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) + jax_params = jax_utils.unreplicate(jax_params).unfreeze() + if model_state is not None: + model_state = jax_utils.unreplicate(model_state) + + if isinstance( + pytorch_model, + (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + pytorch_model = pytorch_model.module + t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) + if key_transform is not None: + t2j.key_transform(key_transform) + if sd_transform is not None: + t2j.sd_transform(sd_transform) + t2j.value_transform(value_transform) + t2j.diff() + t2j.update_jax_model() + return jax_params, model_state, pytorch_model + + +def out_diff(jax_workload, + pytorch_workload, + jax_model_kwargs, + pytorch_model_kwargs, + key_transform=None, + sd_transform=None, + out_transform=None): + jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform, + sd_transform) + out_p, _ = pytorch_workload.model_fn(params=pytorch_model, + **pytorch_model_kwargs) + out_j, _ = jax_workload.model_fn(params=jax_params, + model_state=model_state, + **jax_model_kwargs) + if out_transform is not None: + out_p = out_transform(out_p) + out_j = out_transform(out_j) + + print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) + print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) def key_transform(k): if 'Conv' in k[0]: - k = ('embedding', *k[1:]) + k = ('conv_patch_extract', *k[1:]) elif k[0] == 'Linear_0': k = ('pre_logits', *k[1:]) elif k[0] == 'Linear_1': @@ -35,6 +90,8 @@ 
def key_transform(k): continue if 'CustomBatchNorm' in i: continue + if 'GLU' in i: + pass if 'Linear' in i: if attention: i = { diff --git a/tests/modeldiffs/imagenet_vit/compare_glu.py b/tests/modeldiffs/imagenet_vit/compare_glu.py new file mode 100644 index 000000000..a6f01f971 --- /dev/null +++ b/tests/modeldiffs/imagenet_vit/compare_glu.py @@ -0,0 +1,163 @@ +import os + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ + ImagenetVitGluWorkload as JaxWorkload +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ + ImagenetVitGluWorkload as PytWorkload +from flax import jax_utils +import jax +import numpy as np +import torch + +from tests.modeldiffs.torch2jax_utils import Torch2Jax +from tests.modeldiffs.torch2jax_utils import value_transform + + +#pylint: disable=dangerous-default-value +def torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform=None, + sd_transform=None, + init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): + jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), + **init_kwargs) + pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) + jax_params = jax_utils.unreplicate(jax_params).unfreeze() + if model_state is not None: + model_state = jax_utils.unreplicate(model_state) + + if isinstance( + pytorch_model, + (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + pytorch_model = pytorch_model.module + t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) + if key_transform is not None: + t2j.key_transform(key_transform) + if sd_transform is not None: + t2j.sd_transform(sd_transform) + t2j.value_transform(value_transform) + t2j.diff() + t2j.update_jax_model() + return jax_params, model_state, pytorch_model + + +def out_diff(jax_workload, + pytorch_workload, + jax_model_kwargs, + pytorch_model_kwargs, + key_transform=None, + sd_transform=None, + out_transform=None): + jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform, + sd_transform) + out_p, _ = pytorch_workload.model_fn(params=pytorch_model, + **pytorch_model_kwargs) + out_j, _ = jax_workload.model_fn(params=jax_params, + model_state=model_state, + **jax_model_kwargs) + if out_transform is not None: + out_p = out_transform(out_p) + out_j = out_transform(out_j) + + print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) + print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) + + +def key_transform(k): + if 'Conv' in k[0]: + k = ('conv_patch_extract', *k[1:]) + elif k[0] == 'Linear_0': + k = ('pre_logits', *k[1:]) + elif k[0] == 'Linear_1': + k = ('head', *k[1:]) + + new_key = [] + bn = False + attention = False + ln = False + enc_block = False + for idx, i in enumerate(k): + bn = bn or 'BatchNorm' in i + ln = ln or 'LayerNorm' in i + attention = attention or 'SelfAttention' in i + if 'ModuleList' in i or 'Sequential' in i: + continue + if 'CustomBatchNorm' in i: + continue + if 'GLU' in i: + pass + if 'Linear' in i: + if attention: + i = { + 'Linear_0': 'query', + 'Linear_1': 'key', + 'Linear_2': 'value', + 'Linear_3': 'out', + }[i] + else: + i = i.replace('Linear', 'Dense') + elif 'Conv2d' in i: + i = i.replace('Conv2d', 'Conv') + elif 'Encoder1DBlock' in i: + i = i.replace('Encoder1DBlock', 'encoderblock') + enc_block = True + elif 
'Encoder' in i: + i = 'Transformer' + elif enc_block and 'SelfAttention' in i: + i = 'MultiHeadDotProductAttention_1' + elif enc_block and i == 'LayerNorm_1': + i = 'LayerNorm_2' + elif enc_block and 'MlpBlock' in i: + i = 'MlpBlock_3' + elif idx == 1 and i == 'LayerNorm_0': + i = 'encoder_layernorm' + elif 'weight' in i: + if bn or ln: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +sd_transform = None + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + # Test outputs for identical weights and inputs. + image = torch.randn(2, 3, 224, 224) + + jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} + pyt_batch = {'inputs': image} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ) diff --git a/tests/modeldiffs/imagenet_vit/compare_post_ln.py b/tests/modeldiffs/imagenet_vit/compare_post_ln.py new file mode 100644 index 000000000..e27d77482 --- /dev/null +++ b/tests/modeldiffs/imagenet_vit/compare_post_ln.py @@ -0,0 +1,163 @@ +import os + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ + ImagenetViTPostLNWorkload as JaxWorkload +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ + ImagenetViTPostLNWorkload as PytWorkload +from flax import jax_utils +import jax +import numpy as np +import torch + +from tests.modeldiffs.torch2jax_utils import Torch2Jax +from tests.modeldiffs.torch2jax_utils import value_transform + + +#pylint: disable=dangerous-default-value +def torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform=None, + sd_transform=None, + init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): + jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), + **init_kwargs) + pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) + jax_params = jax_utils.unreplicate(jax_params).unfreeze() + if model_state is not None: + model_state = jax_utils.unreplicate(model_state) + + if isinstance( + pytorch_model, + (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): + pytorch_model = pytorch_model.module + t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) + if key_transform is not None: + t2j.key_transform(key_transform) + if sd_transform is not None: + t2j.sd_transform(sd_transform) + t2j.value_transform(value_transform) + t2j.diff() + t2j.update_jax_model() + return jax_params, model_state, pytorch_model + + +def out_diff(jax_workload, + pytorch_workload, + jax_model_kwargs, + pytorch_model_kwargs, + key_transform=None, + sd_transform=None, + out_transform=None): + jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, + pytorch_workload, + key_transform, + sd_transform) + 
out_p, _ = pytorch_workload.model_fn(params=pytorch_model, + **pytorch_model_kwargs) + out_j, _ = jax_workload.model_fn(params=jax_params, + model_state=model_state, + **jax_model_kwargs) + if out_transform is not None: + out_p = out_transform(out_p) + out_j = out_transform(out_j) + + print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) + print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) + + +def key_transform(k): + if 'Conv' in k[0]: + k = ('conv_patch_extract', *k[1:]) + elif k[0] == 'Linear_0': + k = ('pre_logits', *k[1:]) + elif k[0] == 'Linear_1': + k = ('head', *k[1:]) + + new_key = [] + bn = False + attention = False + ln = False + enc_block = False + for idx, i in enumerate(k): + bn = bn or 'BatchNorm' in i + ln = ln or 'LayerNorm' in i + attention = attention or 'SelfAttention' in i + if 'ModuleList' in i or 'Sequential' in i: + continue + if 'CustomBatchNorm' in i: + continue + if 'GLU' in i: + pass + if 'Linear' in i: + if attention: + i = { + 'Linear_0': 'query', + 'Linear_1': 'key', + 'Linear_2': 'value', + 'Linear_3': 'out', + }[i] + else: + i = i.replace('Linear', 'Dense') + elif 'Conv2d' in i: + i = i.replace('Conv2d', 'Conv') + elif 'Encoder1DBlock' in i: + i = i.replace('Encoder1DBlock', 'encoderblock') + enc_block = True + elif 'Encoder' in i: + i = 'Transformer' + elif enc_block and 'SelfAttention' in i: + i = 'MultiHeadDotProductAttention_1' + elif enc_block and i == 'LayerNorm_1': + i = 'LayerNorm_2' + elif enc_block and 'MlpBlock' in i: + i = 'MlpBlock_3' + elif idx == 1 and i == 'LayerNorm_0': + i = 'encoder_layernorm' + elif 'weight' in i: + if bn or ln: + i = i.replace('weight', 'scale') + else: + i = i.replace('weight', 'kernel') + new_key.append(i) + return tuple(new_key) + + +sd_transform = None + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + # Test outputs for identical weights and inputs. 
+ image = torch.randn(2, 3, 224, 224) + + jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} + pyt_batch = {'inputs': image} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ) From c8a9e728486ea6faffd11ac8df4e90876377b0a7 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Fri, 8 Dec 2023 23:20:40 -0500 Subject: [PATCH 2/5] Clean up model diff --- .../imagenet_jax/randaugment.py | 1 + .../imagenet_vit/imagenet_jax/models.py | 6 +- .../imagenet_vit/imagenet_jax/workload.py | 4 +- .../imagenet_vit/imagenet_pytorch/models.py | 51 ++++-- .../imagenet_vit/imagenet_pytorch/workload.py | 6 +- tests/modeldiffs/imagenet_vit/compare.py | 66 +------ tests/modeldiffs/imagenet_vit/compare_glu.py | 163 ------------------ .../imagenet_vit/compare_post_ln.py | 163 ------------------ tests/modeldiffs/imagenet_vit/glu_compare.py | 52 ++++++ .../imagenet_vit/post_ln_compare.py | 52 ++++++ 10 files changed, 154 insertions(+), 410 deletions(-) delete mode 100644 tests/modeldiffs/imagenet_vit/compare_glu.py delete mode 100644 tests/modeldiffs/imagenet_vit/compare_post_ln.py create mode 100644 tests/modeldiffs/imagenet_vit/glu_compare.py create mode 100644 tests/modeldiffs/imagenet_vit/post_ln_compare.py diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 8fa1c0789..caa77ae35 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,6 +8,7 @@ import math import tensorflow as tf + # from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index 4a97ee661..c88132621 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -88,19 +88,21 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: y = nn.Dropout(rate=self.dropout_rate)(y, train) x = x + y else: + y = x y = nn.SelfAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, name='MultiHeadDotProductAttention_1')( - x) + y) y = nn.Dropout(rate=self.dropout_rate)(y, train) x = x + y x = nn.LayerNorm(name='LayerNorm_0')(x) + y = x y = MlpBlock( mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate, - name='MlpBlock_3')(x, train) + name='MlpBlock_3')(y, train) y = nn.Dropout(rate=self.dropout_rate)(y, train) x = x + y x = nn.LayerNorm(name='LayerNorm_2')(x) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 22fcde66a..1acd58bcd 100644 --- 
a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -32,8 +32,7 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None, - head_zeroinit: bool = True) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: del aux_dropout_rate self._model = models.ViT( dropout_rate=dropout_rate, @@ -41,7 +40,6 @@ def init_model_fn( use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, - head_zeroinit=head_zeroinit, **decode_variant('S/16')) params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py index 053b0ec76..469716d59 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -1,8 +1,8 @@ """PyTorch implementation of refactored and simplified ViT. Adapted from: -https://github.com/huggingface/transformers/tree/main/src/transformers/models/vit. -https://github.com/lucidrains/vit-pytorch. +https://github.com/huggingface/transformers/tree/main/src/transformers/models/vit +and https://github.com/lucidrains/vit-pytorch. """ import math @@ -14,9 +14,12 @@ from algorithmic_efficiency import init_utils from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import \ + MultiheadAttention def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: + """Follows the MoCo v3 logic.""" _, width, h, w = patches.shape device = patches.device y, x = torch.meshgrid(torch.arange(h, device=device), @@ -161,7 +164,11 @@ def __init__(self, self.self_attention1 = SelfAttention(self.width, self.num_heads) self.dropout = nn.Dropout(dropout_rate) self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6) - self.mlp3 = MlpBlock(width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=dropout_rate) + self.mlp3 = MlpBlock( + width=self.width, + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=dropout_rate) def forward(self, x: spec.Tensor) -> spec.Tensor: if not self.use_post_layer_norm: @@ -175,12 +182,14 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: y = self.dropout(y) x = x + y else: - y = self.self_attention1(x) + y = x + y = self.self_attention1(y) y = self.dropout(y) x = x + y x = self.layer_norm0(x) - y = self.mlp3(x) + y = x + y = self.mlp3(y) y = self.dropout(y) x = x + y x = self.layer_norm2(x) @@ -208,8 +217,12 @@ def __init__(self, self.use_post_layer_norm = use_post_layer_norm self.net = nn.ModuleList([ - Encoder1DBlock(self.width, self.mlp_dim, self.num_heads, self.use_glu, self.use_post_layer_norm, dropout_rate) - for _ in range(depth) + Encoder1DBlock(self.width, + self.mlp_dim, + self.num_heads, + self.use_glu, + self.use_post_layer_norm, + dropout_rate) for _ in range(depth) ]) if not self.use_post_layer_norm: @@ -230,7 +243,10 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: class MAPHead(nn.Module): """Multihead Attention Pooling.""" - def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12): + def __init__(self, + width: int, + mlp_dim: Optional[int] = None, + num_heads: int = 12): super().__init__() self.width = 
width self.mlp_dim = mlp_dim @@ -239,16 +255,17 @@ def __init__(self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 1 self.probe = nn.Parameter(torch.zeros((1, 1, self.width))) nn.init.xavier_uniform_(self.probe.data) - self.mha = nn.MultiheadAttention(embed_dim=self.width, num_heads=self.num_heads) - self.layer_nrom = nn.LayerNorm(self.width, eps=1e-6) + self.mha = MultiheadAttention( + self.width, num_heads=self.num_heads, self_attn=False, bias=False) + self.layer_norm = nn.LayerNorm(self.width, eps=1e-6) self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim) - def forward(self, x): + def forward(self, x: spec.Tensor) -> spec.Tensor: n, _, _ = x.shape probe = torch.tile(self.probe, [n, 1, 1]) - x = self.mha(probe, x) - y = self.layer_nrom(x) + x = self.mha(probe, x)[0] + y = self.layer_norm(x) x = x + self.mlp(y) return x[:, 0] @@ -315,6 +332,12 @@ def __init__( if self.num_classes: self.head = nn.Linear(self.width, self.num_classes) + + if self.use_map: + self.map = MAPHead(self.width, self.mlp_dim, self.num_heads) + else: + self.map = None + self.reset_parameters() def reset_parameters(self) -> None: @@ -349,7 +372,7 @@ def forward(self, x: spec.Tensor) -> spec.Tensor: x = self.encoder(x) if self.use_map: - pass + x = self.map(x) else: x = torch.mean(x, dim=1) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index 9e8af3a68..013bc643f 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -28,8 +28,7 @@ def init_model_fn( self, rng: spec.RandomState, dropout_rate: Optional[float] = None, - aux_dropout_rate: Optional[float] = None, - head_zeroinit: bool = True) -> spec.ModelInitState: + aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: del aux_dropout_rate torch.random.manual_seed(rng[0]) model = models.ViT( @@ -38,7 +37,6 @@ def init_model_fn( use_glu=self.use_glu, use_post_layer_norm=self.use_post_layer_norm, use_map=self.use_map, - head_zeroinit=head_zeroinit, **decode_variant('S/16')) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) @@ -98,7 +96,7 @@ def use_post_layer_norm(self) -> bool: return True -class ImagenetViTMapLNWorkload(ImagenetVitWorkload): +class ImagenetViTMapWorkload(ImagenetVitWorkload): @property def use_map(self) -> bool: diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index 3e8b9dcb1..39f2651a0 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -1,72 +1,18 @@ import os +from tests.modeldiffs.diff import out_diff + # Disable GPU access for both jax and pytorch. 
os.environ['CUDA_VISIBLE_DEVICES'] = '' +import jax +import torch + from algorithmic_efficiency import spec from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitWorkload as PytWorkload -from flax import jax_utils -import jax -import numpy as np -import torch - -from tests.modeldiffs.torch2jax_utils import Torch2Jax -from tests.modeldiffs.torch2jax_utils import value_transform - - -#pylint: disable=dangerous-default-value -def torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform=None, - sd_transform=None, - init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): - jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), - **init_kwargs) - pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) - jax_params = jax_utils.unreplicate(jax_params).unfreeze() - if model_state is not None: - model_state = jax_utils.unreplicate(model_state) - - if isinstance( - pytorch_model, - (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): - pytorch_model = pytorch_model.module - t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) - if key_transform is not None: - t2j.key_transform(key_transform) - if sd_transform is not None: - t2j.sd_transform(sd_transform) - t2j.value_transform(value_transform) - t2j.diff() - t2j.update_jax_model() - return jax_params, model_state, pytorch_model - - -def out_diff(jax_workload, - pytorch_workload, - jax_model_kwargs, - pytorch_model_kwargs, - key_transform=None, - sd_transform=None, - out_transform=None): - jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform, - sd_transform) - out_p, _ = pytorch_workload.model_fn(params=pytorch_model, - **pytorch_model_kwargs) - out_j, _ = jax_workload.model_fn(params=jax_params, - model_state=model_state, - **jax_model_kwargs) - if out_transform is not None: - out_p = out_transform(out_p) - out_j = out_transform(out_j) - - print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) - print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) def key_transform(k): @@ -90,8 +36,6 @@ def key_transform(k): continue if 'CustomBatchNorm' in i: continue - if 'GLU' in i: - pass if 'Linear' in i: if attention: i = { diff --git a/tests/modeldiffs/imagenet_vit/compare_glu.py b/tests/modeldiffs/imagenet_vit/compare_glu.py deleted file mode 100644 index a6f01f971..000000000 --- a/tests/modeldiffs/imagenet_vit/compare_glu.py +++ /dev/null @@ -1,163 +0,0 @@ -import os - -# Disable GPU access for both jax and pytorch. 
-os.environ['CUDA_VISIBLE_DEVICES'] = '' - -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ - ImagenetVitGluWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetVitGluWorkload as PytWorkload -from flax import jax_utils -import jax -import numpy as np -import torch - -from tests.modeldiffs.torch2jax_utils import Torch2Jax -from tests.modeldiffs.torch2jax_utils import value_transform - - -#pylint: disable=dangerous-default-value -def torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform=None, - sd_transform=None, - init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): - jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), - **init_kwargs) - pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) - jax_params = jax_utils.unreplicate(jax_params).unfreeze() - if model_state is not None: - model_state = jax_utils.unreplicate(model_state) - - if isinstance( - pytorch_model, - (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): - pytorch_model = pytorch_model.module - t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) - if key_transform is not None: - t2j.key_transform(key_transform) - if sd_transform is not None: - t2j.sd_transform(sd_transform) - t2j.value_transform(value_transform) - t2j.diff() - t2j.update_jax_model() - return jax_params, model_state, pytorch_model - - -def out_diff(jax_workload, - pytorch_workload, - jax_model_kwargs, - pytorch_model_kwargs, - key_transform=None, - sd_transform=None, - out_transform=None): - jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform, - sd_transform) - out_p, _ = pytorch_workload.model_fn(params=pytorch_model, - **pytorch_model_kwargs) - out_j, _ = jax_workload.model_fn(params=jax_params, - model_state=model_state, - **jax_model_kwargs) - if out_transform is not None: - out_p = out_transform(out_p) - out_j = out_transform(out_j) - - print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) - print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) - - -def key_transform(k): - if 'Conv' in k[0]: - k = ('conv_patch_extract', *k[1:]) - elif k[0] == 'Linear_0': - k = ('pre_logits', *k[1:]) - elif k[0] == 'Linear_1': - k = ('head', *k[1:]) - - new_key = [] - bn = False - attention = False - ln = False - enc_block = False - for idx, i in enumerate(k): - bn = bn or 'BatchNorm' in i - ln = ln or 'LayerNorm' in i - attention = attention or 'SelfAttention' in i - if 'ModuleList' in i or 'Sequential' in i: - continue - if 'CustomBatchNorm' in i: - continue - if 'GLU' in i: - pass - if 'Linear' in i: - if attention: - i = { - 'Linear_0': 'query', - 'Linear_1': 'key', - 'Linear_2': 'value', - 'Linear_3': 'out', - }[i] - else: - i = i.replace('Linear', 'Dense') - elif 'Conv2d' in i: - i = i.replace('Conv2d', 'Conv') - elif 'Encoder1DBlock' in i: - i = i.replace('Encoder1DBlock', 'encoderblock') - enc_block = True - elif 'Encoder' in i: - i = 'Transformer' - elif enc_block and 'SelfAttention' in i: - i = 'MultiHeadDotProductAttention_1' - elif enc_block and i == 'LayerNorm_1': - i = 'LayerNorm_2' - elif enc_block and 'MlpBlock' in i: - i = 'MlpBlock_3' - elif idx == 1 and i == 'LayerNorm_0': - i = 'encoder_layernorm' - elif 'weight' in i: - if bn or ln: - i = i.replace('weight', 'scale') - else: - i = i.replace('weight', 'kernel') - 
new_key.append(i) - return tuple(new_key) - - -sd_transform = None - -if __name__ == '__main__': - # pylint: disable=locally-disabled, not-callable - - jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() - - # Test outputs for identical weights and inputs. - image = torch.randn(2, 3, 224, 224) - - jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} - pyt_batch = {'inputs': image} - - pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) - - jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) - - out_diff( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=None, - ) diff --git a/tests/modeldiffs/imagenet_vit/compare_post_ln.py b/tests/modeldiffs/imagenet_vit/compare_post_ln.py deleted file mode 100644 index e27d77482..000000000 --- a/tests/modeldiffs/imagenet_vit/compare_post_ln.py +++ /dev/null @@ -1,163 +0,0 @@ -import os - -# Disable GPU access for both jax and pytorch. -os.environ['CUDA_VISIBLE_DEVICES'] = '' - -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ - ImagenetViTPostLNWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ - ImagenetViTPostLNWorkload as PytWorkload -from flax import jax_utils -import jax -import numpy as np -import torch - -from tests.modeldiffs.torch2jax_utils import Torch2Jax -from tests.modeldiffs.torch2jax_utils import value_transform - - -#pylint: disable=dangerous-default-value -def torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform=None, - sd_transform=None, - init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0, head_zeroinit=False)): - jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0), - **init_kwargs) - pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs) - jax_params = jax_utils.unreplicate(jax_params).unfreeze() - if model_state is not None: - model_state = jax_utils.unreplicate(model_state) - - if isinstance( - pytorch_model, - (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): - pytorch_model = pytorch_model.module - t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params) - if key_transform is not None: - t2j.key_transform(key_transform) - if sd_transform is not None: - t2j.sd_transform(sd_transform) - t2j.value_transform(value_transform) - t2j.diff() - t2j.update_jax_model() - return jax_params, model_state, pytorch_model - - -def out_diff(jax_workload, - pytorch_workload, - jax_model_kwargs, - pytorch_model_kwargs, - key_transform=None, - sd_transform=None, - out_transform=None): - jax_params, model_state, pytorch_model = torch2jax_with_zeroinit(jax_workload, - pytorch_workload, - key_transform, - sd_transform) - out_p, _ = pytorch_workload.model_fn(params=pytorch_model, - **pytorch_model_kwargs) - out_j, _ = jax_workload.model_fn(params=jax_params, - model_state=model_state, - **jax_model_kwargs) - if out_transform is not None: - out_p = out_transform(out_p) - out_j = out_transform(out_j) - - print(np.abs(out_p.detach().numpy() - np.array(out_j)).max()) - print(np.abs(out_p.detach().numpy() - np.array(out_j)).min()) - - -def 
key_transform(k): - if 'Conv' in k[0]: - k = ('conv_patch_extract', *k[1:]) - elif k[0] == 'Linear_0': - k = ('pre_logits', *k[1:]) - elif k[0] == 'Linear_1': - k = ('head', *k[1:]) - - new_key = [] - bn = False - attention = False - ln = False - enc_block = False - for idx, i in enumerate(k): - bn = bn or 'BatchNorm' in i - ln = ln or 'LayerNorm' in i - attention = attention or 'SelfAttention' in i - if 'ModuleList' in i or 'Sequential' in i: - continue - if 'CustomBatchNorm' in i: - continue - if 'GLU' in i: - pass - if 'Linear' in i: - if attention: - i = { - 'Linear_0': 'query', - 'Linear_1': 'key', - 'Linear_2': 'value', - 'Linear_3': 'out', - }[i] - else: - i = i.replace('Linear', 'Dense') - elif 'Conv2d' in i: - i = i.replace('Conv2d', 'Conv') - elif 'Encoder1DBlock' in i: - i = i.replace('Encoder1DBlock', 'encoderblock') - enc_block = True - elif 'Encoder' in i: - i = 'Transformer' - elif enc_block and 'SelfAttention' in i: - i = 'MultiHeadDotProductAttention_1' - elif enc_block and i == 'LayerNorm_1': - i = 'LayerNorm_2' - elif enc_block and 'MlpBlock' in i: - i = 'MlpBlock_3' - elif idx == 1 and i == 'LayerNorm_0': - i = 'encoder_layernorm' - elif 'weight' in i: - if bn or ln: - i = i.replace('weight', 'scale') - else: - i = i.replace('weight', 'kernel') - new_key.append(i) - return tuple(new_key) - - -sd_transform = None - -if __name__ == '__main__': - # pylint: disable=locally-disabled, not-callable - - jax_workload = JaxWorkload() - pytorch_workload = PytWorkload() - - # Test outputs for identical weights and inputs. - image = torch.randn(2, 3, 224, 224) - - jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} - pyt_batch = {'inputs': image} - - pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, - model_state=None, - mode=spec.ForwardPassMode.EVAL, - rng=None, - update_batch_norm=False) - - jax_model_kwargs = dict( - augmented_and_preprocessed_input_batch=jax_batch, - mode=spec.ForwardPassMode.EVAL, - rng=jax.random.PRNGKey(0), - update_batch_norm=False) - - out_diff( - jax_workload=jax_workload, - pytorch_workload=pytorch_workload, - jax_model_kwargs=jax_model_kwargs, - pytorch_model_kwargs=pytorch_model_kwargs, - key_transform=key_transform, - sd_transform=None, - ) diff --git a/tests/modeldiffs/imagenet_vit/glu_compare.py b/tests/modeldiffs/imagenet_vit/glu_compare.py new file mode 100644 index 000000000..444f1230a --- /dev/null +++ b/tests/modeldiffs/imagenet_vit/glu_compare.py @@ -0,0 +1,52 @@ +import os + +from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.imagenet_vit.compare import key_transform + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ + ImagenetVitGluWorkload as JaxWorkload +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ + ImagenetVitGluWorkload as PytWorkload + +sd_transform = None + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + # Test outputs for identical weights and inputs. 
+ image = torch.randn(2, 3, 224, 224) + + jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} + pyt_batch = {'inputs': image} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ) diff --git a/tests/modeldiffs/imagenet_vit/post_ln_compare.py b/tests/modeldiffs/imagenet_vit/post_ln_compare.py new file mode 100644 index 000000000..8bf0bef7e --- /dev/null +++ b/tests/modeldiffs/imagenet_vit/post_ln_compare.py @@ -0,0 +1,52 @@ +import os + +from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.imagenet_vit.compare import key_transform + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ + ImagenetViTPostLNWorkload as JaxWorkload +from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ + ImagenetViTPostLNWorkload as PytWorkload + +sd_transform = None + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PytWorkload() + + # Test outputs for identical weights and inputs. + image = torch.randn(2, 3, 224, 224) + + jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} + pyt_batch = {'inputs': image} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=None, + ) From ecf8220edf11ecde32511f4dbe97888307b2cf86 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Fri, 8 Dec 2023 23:33:38 -0500 Subject: [PATCH 3/5] Add docker image --- .../imagenet_resnet/imagenet_jax/randaugment.py | 3 +-- algorithmic_efficiency/workloads/workloads.py | 12 ++++++++++++ docker/scripts/startup.sh | 3 ++- tests/modeldiffs/imagenet_vit/compare.py | 3 +-- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index caa77ae35..5f92b1482 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,8 +8,7 @@ import math import tensorflow as tf - -# from tensorflow_addons import image as contrib_image +from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. 
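Note on the GLU variant introduced above: in both the Flax and PyTorch MlpBlock changes, the GELU activations are gated elementwise by a second mlp_dim-sized linear projection before the output projection. Below is a minimal standalone PyTorch sketch of that pattern; the class name GluMlpSketch, the tanh-approximate GELU default, and the toy shapes in the comment are illustrative assumptions, not the repo's exact API.

import torch
from torch import nn


class GluMlpSketch(nn.Module):
  """Illustrative GLU-style MLP: out = W2(dropout(gelu(W1 x) * Wg(gelu(W1 x))))."""

  def __init__(self, width: int, mlp_dim: int, dropout_rate: float = 0.0):
    super().__init__()
    self.linear1 = nn.Linear(width, mlp_dim)
    self.act = nn.GELU(approximate='tanh')
    self.gate = nn.Linear(mlp_dim, mlp_dim)  # extra projection used as the gate
    self.dropout = nn.Dropout(dropout_rate)
    self.linear2 = nn.Linear(mlp_dim, width)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    h = self.act(self.linear1(x))  # (..., mlp_dim)
    h = h * self.gate(h)           # elementwise gating of the activations
    h = self.dropout(h)
    return self.linear2(h)         # project back to (..., width)


# Shape check only: (batch, tokens, width) in and out, e.g.
# GluMlpSketch(384, 1536)(torch.randn(2, 196, 384)).shape == (2, 196, 384)
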
diff --git a/algorithmic_efficiency/workloads/workloads.py b/algorithmic_efficiency/workloads/workloads.py index 6cc53b7dd..bf444ea36 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algorithmic_efficiency/workloads/workloads.py @@ -56,6 +56,18 @@ 'workload_path': 'imagenet_vit/imagenet', 'workload_class_name': 'ImagenetVitWorkload', }, + 'imagenet_vit_glu': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetVitGluWorkload', + }, + 'imagenet_vit_post_ln': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetViTPostLNWorkload', + }, + 'imagenet_vit_map': { + 'workload_path': 'imagenet_vit/imagenet', + 'workload_class_name': 'ImagenetViTMapLNWorkload', + }, 'librispeech_conformer': { 'workload_path': 'librispeech_conformer/librispeech', 'workload_class_name': 'LibriSpeechConformerWorkload', diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 3f7458e4b..3b366b71c 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -113,7 +113,8 @@ done VALID_DATASETS=("criteo1tb" "imagenet" "fastmri" "ogbg" "librispeech" \ "wmt" "mnist") VALID_WORKLOADS=("criteo1tb" "imagenet_resnet" "imagenet_resnet_silu" "imagenet_resnet_gelu" \ - "imagenet_resnet_large_bn_init" "imagenet_vit" "fastmri" "ogbg" \ + "imagenet_resnet_large_bn_init" "imagenet_vit" "imagenet_vit_glu" \ + "imagenet_vit_post_ln" "imagenet_vit_map" "fastmri" "ogbg" \ "wmt" "librispeech_deepspeech" "librispeech_conformer" "mnist" \ "criteo1tb_resnet" "criteo1tb_layernorm" "criteo_embed_init" \ "conformer_layernorm" "conformer_attention_temperature" \ diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index 39f2651a0..bf7d6dfa5 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -1,7 +1,5 @@ import os -from tests.modeldiffs.diff import out_diff - # Disable GPU access for both jax and pytorch. 
os.environ['CUDA_VISIBLE_DEVICES'] = '' @@ -13,6 +11,7 @@ ImagenetVitWorkload as JaxWorkload from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitWorkload as PytWorkload +from tests.modeldiffs.diff import out_diff def key_transform(k): From 290807795fc8a1cf392dd7e94823569d5b651e40 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Fri, 8 Dec 2023 23:40:36 -0500 Subject: [PATCH 4/5] Lint fix --- .../imagenet_vit/imagenet_jax/models.py | 91 ++++++++++--------- 1 file changed, 46 insertions(+), 45 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index c88132621..32e748ec7 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -50,9 +50,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: x = nn.gelu(x) if self.use_glu: - y = nn.Dense( - self.mlp_dim, - **inits)(x) + y = nn.Dense(self.mlp_dim, **inits)(x) x = x * y x = nn.Dropout(rate=self.dropout_rate)(x, train) @@ -71,41 +69,45 @@ class Encoder1DBlock(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: if not self.use_post_layer_norm: - y = nn.LayerNorm(name='LayerNorm_0')(x) - y = nn.SelfAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - deterministic=train, - name='MultiHeadDotProductAttention_1')( - y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - - y = nn.LayerNorm(name='LayerNorm_2')(x) - y = MlpBlock( - mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y + y = nn.LayerNorm(name='LayerNorm_0')(x) + y = nn.SelfAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + y) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + + y = nn.LayerNorm(name='LayerNorm_2')(x) + y = MlpBlock( + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=self.dropout_rate, + name='MlpBlock_3')(y, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y else: - y = x - y = nn.SelfAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - deterministic=train, - name='MultiHeadDotProductAttention_1')( - y) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - x = nn.LayerNorm(name='LayerNorm_0')(x) - - y = x - y = MlpBlock( - mlp_dim=self.mlp_dim, use_glu=self.use_glu, dropout_rate=self.dropout_rate, - name='MlpBlock_3')(y, train) - y = nn.Dropout(rate=self.dropout_rate)(y, train) - x = x + y - x = nn.LayerNorm(name='LayerNorm_2')(x) + y = x + y = nn.SelfAttention( + num_heads=self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + deterministic=train, + name='MultiHeadDotProductAttention_1')( + y) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_0')(x) + + y = x + y = MlpBlock( + mlp_dim=self.mlp_dim, + use_glu=self.use_glu, + dropout_rate=self.dropout_rate, + name='MlpBlock_3')(y, train) + y = nn.Dropout(rate=self.dropout_rate)(y, train) + x = x + y + x = nn.LayerNorm(name='LayerNorm_2')(x) return x @@ -141,12 +143,13 @@ class MAPHead(nn.Module): """Multihead Attention Pooling.""" mlp_dim: Optional[int] = None # Defaults to 4x input dim 
num_heads: int = 12 + @nn.compact def __call__(self, x): n, _, d = x.shape probe = self.param('probe', - nn.initializers.xavier_uniform(), - (1, 1, d), x.dtype) + nn.initializers.xavier_uniform(), (1, 1, d), + x.dtype) probe = jnp.tile(probe, [n, 1, 1]) x = nn.MultiHeadDotProductAttention( @@ -171,9 +174,9 @@ class ViT(nn.Module): dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0. reinit: Optional[Sequence[str]] = None head_zeroinit: bool = True - use_glu: bool = False, - use_post_layer_norm: bool = False, - use_map: bool = False, + use_glu: bool = False + use_post_layer_norm: bool = False + use_map: bool = False def get_posemb(self, seqshape: tuple, @@ -214,9 +217,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor: x, train=not train) if self.use_map: - x = MAPHead(num_heads=self.num_heads, - mlp_dim=self.mlp_dim - )(x) + x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x) else: x = jnp.mean(x, axis=1) From 72573f4eea6a81062dd975b8eb83bb440979d7ce Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Wed, 3 Jan 2024 20:34:31 -0500 Subject: [PATCH 5/5] Fix names --- .../workloads/imagenet_vit/imagenet_jax/workload.py | 4 ++-- .../workloads/imagenet_vit/imagenet_pytorch/workload.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 1acd58bcd..4b12247c2 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -95,14 +95,14 @@ def use_glu(self) -> bool: return True -class ImagenetViTPostLNWorkload(ImagenetVitWorkload): +class ImagenetVitPostLNWorkload(ImagenetVitWorkload): @property def use_post_layer_norm(self) -> bool: return True -class ImagenetViTMapLNWorkload(ImagenetVitWorkload): +class ImagenetVitMapWorkload(ImagenetVitWorkload): @property def use_map(self) -> bool: diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py index 013bc643f..645b795ca 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -89,14 +89,14 @@ def use_glu(self) -> bool: return True -class ImagenetViTPostLNWorkload(ImagenetVitWorkload): +class ImagenetVitPostLNWorkload(ImagenetVitWorkload): @property def use_post_layer_norm(self) -> bool: return True -class ImagenetViTMapWorkload(ImagenetVitWorkload): +class ImagenetVitMapWorkload(ImagenetVitWorkload): @property def use_map(self) -> bool:
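
End-of-series note on the residual orderings added by this patch set: Encoder1DBlock keeps the default pre-LayerNorm path and adds a use_post_layer_norm path that normalizes after each residual sum, in which case the Encoder also skips its final encoder_layernorm. The function below is a condensed PyTorch sketch of the two orderings only, with dropout omitted; encoder_block and the toy attention/MLP callables in the shape check are assumptions for illustration, not the repo's code.

import torch
from torch import nn


def encoder_block(x, attn, mlp, ln_a, ln_b, post_ln=False):
  """Residual orderings used by the two Encoder1DBlock branches (dropout omitted)."""
  if not post_ln:
    # Pre-LN (default): normalize each sublayer's input, add the raw residual.
    x = x + attn(ln_a(x))
    x = x + mlp(ln_b(x))
  else:
    # Post-LN variant: residual add first, then normalize the sum.
    x = ln_a(x + attn(x))
    x = ln_b(x + mlp(x))
  return x


# Toy shape check for both orderings.
width = 16
mha = nn.MultiheadAttention(width, num_heads=4, batch_first=True)
attn = lambda t: mha(t, t, t)[0]
mlp = nn.Sequential(nn.Linear(width, 4 * width), nn.GELU(), nn.Linear(4 * width, width))
x = torch.randn(2, 5, width)
print(encoder_block(x, attn, mlp, nn.LayerNorm(width), nn.LayerNorm(width)).shape)
print(encoder_block(x, attn, mlp, nn.LayerNorm(width), nn.LayerNorm(width), post_ln=True).shape)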
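
Similarly, the use_map pooling path replaces mean pooling over tokens with a learned probe that attends over the whole sequence. A rough PyTorch sketch of that MAP head follows; it uses torch.nn.MultiheadAttention for brevity, whereas the patches above wire in the MultiheadAttention module from the WMT workload, so the class name MapHeadSketch and these layer choices are approximations rather than the exact implementation.

import torch
from torch import nn


class MapHeadSketch(nn.Module):
  """Multihead attention pooling: a learned probe token attends over all patches."""

  def __init__(self, width: int, mlp_dim: int, num_heads: int = 12):
    super().__init__()
    self.probe = nn.Parameter(torch.zeros(1, 1, width))
    nn.init.xavier_uniform_(self.probe)
    self.mha = nn.MultiheadAttention(width, num_heads, batch_first=True)
    self.layer_norm = nn.LayerNorm(width, eps=1e-6)
    self.mlp = nn.Sequential(
        nn.Linear(width, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, width))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    n = x.shape[0]
    probe = torch.tile(self.probe, (n, 1, 1))            # (n, 1, width)
    pooled, _ = self.mha(probe, x, x)                    # probe is the query; x is key/value
    pooled = pooled + self.mlp(self.layer_norm(pooled))  # residual MLP on the pooled token
    return pooled[:, 0]                                  # (n, width)


# e.g. MapHeadSketch(384, 1536, num_heads=6)(torch.randn(2, 196, 384)).shape -> (2, 384)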