Default Transformer MHA to batch_first, enables fastpath for vit, still need mask changes for text
rwightman committed Jul 3, 2024
1 parent 1598cd1 commit edeca45
Showing 3 changed files with 13 additions and 6 deletions.
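For readers unfamiliar with the motivation: PyTorch's nn.MultiheadAttention can only take its fused inference fast path when, among other conditions, the module is constructed with batch_first=True, gradients are disabled, and attention weights are not requested. The sketch below is illustrative only (it is not part of the commit; the dimensions and shapes are made up) and shows the batch-first calling convention the diff switches to:

import torch
import torch.nn as nn

# Illustrative sketch, not repo code: with batch_first=True the module
# consumes (batch, seq, dim) tensors directly, so no NLD <-> LND transposes
# are needed and the fused inference fast path becomes eligible.
d_model, n_head = 512, 8
mha = nn.MultiheadAttention(d_model, n_head, batch_first=True).eval()

x = torch.randn(4, 196, d_model)                # (batch, tokens, width), e.g. ViT patches
with torch.inference_mode():
    out, _ = mha(x, x, x, need_weights=False)   # self-attention, no mask
print(out.shape)                                # torch.Size([4, 196, 512])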
15 changes: 11 additions & 4 deletions src/open_clip/transformer.py
@@ -217,11 +217,12 @@ def __init__(
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = LayerNorm,
             is_cross_attention: bool = False,
+            batch_first: bool = True,
     ):
         super().__init__()
 
         self.ln_1 = norm_layer(d_model)
-        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first)
         self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
         if is_cross_attention:
             self.ln_1_kv = norm_layer(d_model)
@@ -283,7 +284,8 @@ def __init__(
 
         self.ln_1 = norm_layer(d_model)
         self.attn = Attention(
-            d_model, n_head,
+            d_model,
+            n_head,
             scaled_cosine=scale_cosine_attn,
             scale_heads=scale_heads,
             batch_first=batch_first,
@@ -324,10 +326,12 @@ def __init__(
             ls_init_value: float = None,
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = LayerNorm,
+            batch_first: bool = False,
     ):
         super().__init__()
         self.width = width
         self.layers = layers
+        self.batch_first = batch_first
         self.grad_checkpointing = False
 
         self.resblocks = nn.ModuleList([
@@ -338,6 +342,7 @@ def __init__(
                 ls_init_value=ls_init_value,
                 act_layer=act_layer,
                 norm_layer=norm_layer,
+                batch_first=batch_first,
             )
             for _ in range(layers)
         ])
@@ -348,14 +353,16 @@ def get_cast_dtype(self) -> torch.dtype:
         return self.resblocks[0].mlp.c_fc.weight.dtype
 
     def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
-        x = x.transpose(0, 1).contiguous()  # NLD -> LND
+        if not self.batch_first:
+            x = x.transpose(0, 1).contiguous()  # NLD -> LND
         for r in self.resblocks:
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                 x = checkpoint(r, x, None, None, attn_mask)
             else:
                 x = r(x, attn_mask=attn_mask)
-        x = x.transpose(0, 1)  # LND -> NLD
+        if not self.batch_first:
+            x = x.transpose(0, 1)  # LND -> NLD
         return x


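As a usage illustration (a hypothetical snippet, not code from this repository; the import path and hyperparameter values are assumptions based on the constructor signature shown in the hunks above), a batch-first Transformer keeps tokens in (batch, seq, width) end to end:

import torch
from open_clip.transformer import Transformer

# Assumed constructor arguments per the diff above; the values are made up.
encoder = Transformer(width=768, layers=12, heads=12, batch_first=True).eval()

tokens = torch.randn(2, 197, 768)        # (batch, CLS + 196 patches, width)
with torch.inference_mode():
    out = encoder(tokens)                # no attn_mask; stays (2, 197, 768) throughout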
2 changes: 1 addition & 1 deletion src/training/train.py
@@ -271,7 +271,7 @@ def evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None):
         cumulative_loss = 0.0
         cumulative_gen_loss = 0.0
         all_image_features, all_text_features = [], []
-        with torch.no_grad():
+        with torch.inference_mode():
             for i, batch in enumerate(dataloader):
                 images, texts = batch
                 images = images.to(device=device, dtype=input_dtype, non_blocking=True)
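The training-script change swaps torch.no_grad() for torch.inference_mode() in the evaluation loop. Both disable gradient computation, but inference_mode() additionally skips view and version-counter tracking, with the restriction that tensors created inside it cannot later be used in operations recorded by autograd. A minimal comparison (illustrative only, not repo code):

import torch

x = torch.ones(3, requires_grad=True)

with torch.no_grad():
    a = x * 2                    # no graph recorded; `a` is a regular tensor

with torch.inference_mode():
    b = x * 2                    # cheaper bookkeeping; `b` is an inference tensor

print(a.requires_grad, b.requires_grad)   # False False
print(torch.is_inference(b))              # True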
2 changes: 1 addition & 1 deletion src/training/zero_shot.py
@@ -18,7 +18,7 @@ def run(model, classifier, dataloader, args):
     autocast = get_autocast(args.precision)
     input_dtype = get_input_dtype(args.precision)
 
-    with torch.no_grad():
+    with torch.inference_mode():
         top1, top5, n = 0., 0., 0.
         for images, target in tqdm(dataloader, unit_scale=args.batch_size):
             images = images.to(device=args.device, dtype=input_dtype)
