Add workload variants for ImageNet ViT #599

Merged
8 commits merged on Jan 4, 2024
Changes from all commits
algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py (106 changes: 88 additions & 18 deletions)
@@ -34,6 +34,7 @@ def posemb_sincos_2d(h: int,
class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block."""
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
use_glu: bool = False
dropout_rate: float = 0.0

@nn.compact
@@ -47,6 +48,11 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
d = x.shape[2]
x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x)
x = nn.gelu(x)

if self.use_glu:
y = nn.Dense(self.mlp_dim, **inits)(x)
x = x * y

x = nn.Dropout(rate=self.dropout_rate)(x, train)
x = nn.Dense(d, **inits)(x)
return x
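
The GLU path gates the GELU activations with a second Dense projection of the same width before the final output projection. Below is a minimal usage sketch, not part of the PR: it assumes the algorithmic_efficiency package and its JAX dependencies are importable, and the shapes and seed are illustrative.

    import jax
    import jax.numpy as jnp

    from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax import models

    # Illustrative shapes: (batch, tokens, width) = (8, 196, 384).
    x = jnp.ones((8, 196, 384))
    glu_block = models.MlpBlock(mlp_dim=1536, use_glu=True, dropout_rate=0.0)
    variables = glu_block.init(jax.random.PRNGKey(0), x)
    y = glu_block.apply(variables, x)  # gated MLP output, same shape as the input
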
@@ -56,26 +62,53 @@ class Encoder1DBlock(nn.Module):
"""Single transformer encoder block (MHSA + MLP)."""
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12
use_glu: bool = False
use_post_layer_norm: bool = False
dropout_rate: float = 0.0

@nn.compact
def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
-    y = nn.LayerNorm(name='LayerNorm_0')(x)
-    y = nn.SelfAttention(
-        num_heads=self.num_heads,
-        kernel_init=nn.initializers.xavier_uniform(),
-        deterministic=train,
-        name='MultiHeadDotProductAttention_1')(
-            y)
-    y = nn.Dropout(rate=self.dropout_rate)(y, train)
-    x = x + y
-
-    y = nn.LayerNorm(name='LayerNorm_2')(x)
-    y = MlpBlock(
-        mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate,
-        name='MlpBlock_3')(y, train)
-    y = nn.Dropout(rate=self.dropout_rate)(y, train)
-    x = x + y
+    if not self.use_post_layer_norm:
+      y = nn.LayerNorm(name='LayerNorm_0')(x)
+      y = nn.SelfAttention(
+          num_heads=self.num_heads,
+          kernel_init=nn.initializers.xavier_uniform(),
+          deterministic=train,
+          name='MultiHeadDotProductAttention_1')(
+              y)
+      y = nn.Dropout(rate=self.dropout_rate)(y, train)
+      x = x + y
+
+      y = nn.LayerNorm(name='LayerNorm_2')(x)
+      y = MlpBlock(
+          mlp_dim=self.mlp_dim,
+          use_glu=self.use_glu,
+          dropout_rate=self.dropout_rate,
+          name='MlpBlock_3')(y, train)
+      y = nn.Dropout(rate=self.dropout_rate)(y, train)
+      x = x + y
+    else:
+      y = x
+      y = nn.SelfAttention(
+          num_heads=self.num_heads,
+          kernel_init=nn.initializers.xavier_uniform(),
+          deterministic=train,
+          name='MultiHeadDotProductAttention_1')(
+              y)
+      y = nn.Dropout(rate=self.dropout_rate)(y, train)
+      x = x + y
+      x = nn.LayerNorm(name='LayerNorm_0')(x)
+
+      y = x
+      y = MlpBlock(
+          mlp_dim=self.mlp_dim,
+          use_glu=self.use_glu,
+          dropout_rate=self.dropout_rate,
+          name='MlpBlock_3')(y, train)
+      y = nn.Dropout(rate=self.dropout_rate)(y, train)
+      x = x + y
+      x = nn.LayerNorm(name='LayerNorm_2')(x)

return x
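
The post-LN branch applies the same sublayers in the same order; each LayerNorm just moves from the input of a residual branch to the output of the residual sum, so both settings produce the same parameter tree, and with use_post_layer_norm=True the Encoder below also drops its final encoder_layernorm because every block already ends in one. A quick equivalence sketch (illustrative shapes, assuming the package is importable):

    import jax
    import jax.numpy as jnp

    from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax import models

    x = jnp.ones((2, 196, 384))  # (batch, tokens, width), illustrative
    rng = jax.random.PRNGKey(0)

    pre_ln = models.Encoder1DBlock(mlp_dim=1536, num_heads=6)
    post_ln = models.Encoder1DBlock(mlp_dim=1536, num_heads=6, use_post_layer_norm=True)

    # Both branches reuse the same submodule names, so the parameter trees match.
    assert (jax.tree_util.tree_structure(pre_ln.init(rng, x)) ==
            jax.tree_util.tree_structure(post_ln.init(rng, x)))
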


@@ -85,6 +118,8 @@ class Encoder(nn.Module):
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12
dropout_rate: float = 0.0
use_glu: bool = False
use_post_layer_norm: bool = False

@nn.compact
def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
@@ -94,9 +129,36 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
name=f'encoderblock_{lyr}',
mlp_dim=self.mlp_dim,
num_heads=self.num_heads,
use_glu=self.use_glu,
use_post_layer_norm=self.use_post_layer_norm,
dropout_rate=self.dropout_rate)
x = block(x, train)
-    return nn.LayerNorm(name='encoder_layernorm')(x)
+    if not self.use_post_layer_norm:
+      return nn.LayerNorm(name='encoder_layernorm')(x)
+    else:
+      return x


class MAPHead(nn.Module):
"""Multihead Attention Pooling."""
mlp_dim: Optional[int] = None # Defaults to 4x input dim
num_heads: int = 12

@nn.compact
def __call__(self, x):
n, _, d = x.shape
probe = self.param('probe',
nn.initializers.xavier_uniform(), (1, 1, d),
x.dtype)
probe = jnp.tile(probe, [n, 1, 1])

x = nn.MultiHeadDotProductAttention(
num_heads=self.num_heads,
kernel_init=nn.initializers.xavier_uniform())(probe, x)

y = nn.LayerNorm()(x)
x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
return x[:, 0]
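
MAP pooling replaces the mean over tokens with cross-attention from a single learned probe token and returns that probe's representation. A usage sketch (illustrative shapes, assuming the package is importable):

    import jax
    import jax.numpy as jnp

    from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax import models

    tokens = jnp.ones((4, 196, 384))  # (batch, tokens, width), illustrative
    map_head = models.MAPHead(num_heads=6, mlp_dim=1536)
    variables = map_head.init(jax.random.PRNGKey(0), tokens)
    pooled = map_head.apply(variables, tokens)  # one vector per example: shape (4, 384)
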


class ViT(nn.Module):
@@ -112,6 +174,9 @@ class ViT(nn.Module):
dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0.
reinit: Optional[Sequence[str]] = None
head_zeroinit: bool = True
use_glu: bool = False
use_post_layer_norm: bool = False
use_map: bool = False

def get_posemb(self,
seqshape: tuple,
@@ -145,11 +210,16 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor:
depth=self.depth,
mlp_dim=self.mlp_dim,
num_heads=self.num_heads,
use_glu=self.use_glu,
use_post_layer_norm=self.use_post_layer_norm,
dropout_rate=dropout_rate,
name='Transformer')(
x, train=not train)

-    x = jnp.mean(x, axis=1)
+    if self.use_map:
+      x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
+    else:
+      x = jnp.mean(x, axis=1)

if self.rep_size:
rep_size = self.width if self.rep_size is True else self.rep_size
@@ -37,6 +37,9 @@ def init_model_fn(
self._model = models.ViT(
dropout_rate=dropout_rate,
num_classes=self._num_classes,
use_glu=self.use_glu,
use_post_layer_norm=self.use_post_layer_norm,
use_map=self.use_map,
**decode_variant('S/16'))
params, model_state = self.initialized(rng, self._model)
self._param_shapes = param_utils.jax_param_shapes(params)
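
For reference, a sketch of what these flags select once they reach models.ViT, mirroring the init_model_fn call above; the input shape, the seed, and calling decode_variant directly are illustrative choices, not prescribed by the PR.

    import jax
    import jax.numpy as jnp

    from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax import models

    model = models.ViT(
        num_classes=1000,
        use_glu=True,               # GLU workload variant
        use_post_layer_norm=False,  # set True instead for the post-LN variant
        use_map=False,              # set True instead for the MAP-pooling variant
        **models.decode_variant('S/16'))

    images = jnp.ones((2, 224, 224, 3))  # NHWC, illustrative
    variables = model.init(jax.random.PRNGKey(0), images, train=False)
    logits = model.apply(variables, images, train=False)  # shape (2, 1000)
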
@@ -83,3 +86,24 @@ def _eval_model_on_split(self,
rng,
data_dir,
global_step)


class ImagenetVitGluWorkload(ImagenetVitWorkload):

@property
def use_glu(self) -> bool:
return True


class ImagenetVitPostLNWorkload(ImagenetVitWorkload):

@property
def use_post_layer_norm(self) -> bool:
return True


class ImagenetVitMapWorkload(ImagenetVitWorkload):

@property
def use_map(self) -> bool:
return True
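
Each variant is a thin subclass that flips a single property read by init_model_fn above. A hypothetical selection sketch, for illustration only: it assumes the base ImagenetVitWorkload defines these properties with False defaults and takes no constructor arguments, neither of which is shown in this diff.

    # Hypothetical wiring; the registry key names are illustrative.
    variant_workloads = {
        'imagenet_vit_glu': ImagenetVitGluWorkload,
        'imagenet_vit_post_ln': ImagenetVitPostLNWorkload,
        'imagenet_vit_map': ImagenetVitMapWorkload,
    }

    workload = variant_workloads['imagenet_vit_glu']()  # assumes a no-arg constructor
    assert workload.use_glu                  # overridden to True above
    assert not workload.use_post_layer_norm  # base-class default, assumed False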