
Commit

Attention pooler always batch_first, update MultiModalTransformer to default batch_first
rwightman committed Jul 3, 2024
1 parent 36d4046 commit fbcabec
Showing 1 changed file with 16 additions and 11 deletions.
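With batch_first=True, nn.MultiheadAttention consumes and returns (batch, seq, dim) tensors directly, so the NLD -> LND permutes around the pooler's attention call can be dropped. A minimal sketch of that call pattern, with illustrative sizes and variable names that are not taken from the repo:

```python
import torch
import torch.nn as nn

d_model, context_dim, n_head, n_queries = 64, 96, 8, 16   # illustrative sizes

# batch_first=True: query/key/value are (N, L, E), no NLD <-> LND permutes needed
attn = nn.MultiheadAttention(
    d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)

tokens = torch.randn(4, 50, context_dim)                     # image tokens, (N, L, context_dim)
query = torch.randn(n_queries, d_model)                      # learned pooling queries
query = query.unsqueeze(0).expand(tokens.shape[0], -1, -1)   # broadcast to (N, n_queries, d_model)

pooled, _ = attn(query, tokens, tokens, need_weights=False)
print(pooled.shape)                                          # torch.Size([4, 16, 64])
```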
27 changes: 16 additions & 11 deletions src/open_clip/transformer.py
@@ -191,20 +191,20 @@ def __init__(
             context_dim: int,
             n_head: int = 8,
             n_queries: int = 256,
-            norm_layer: Callable = LayerNorm
+            norm_layer: Callable = LayerNorm,
     ):
         super().__init__()
         self.query = nn.Parameter(torch.randn(n_queries, d_model))
-        self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
+        self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)
         self.ln_q = norm_layer(d_model)
         self.ln_k = norm_layer(context_dim)

     def forward(self, x: torch.Tensor):
-        x = self.ln_k(x).permute(1, 0, 2)  # NLD -> LND
+        x = self.ln_k(x)
         N = x.shape[1]
         q = self.ln_q(self.query)
         out = self.attn(q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False)[0]
-        return out.permute(1, 0, 2)  # LND -> NLD
+        return out


 class ResidualAttentionBlock(nn.Module):
@@ -821,6 +821,7 @@ def __init__(
             act_layer: Callable = nn.GELU,
             norm_layer: Callable = LayerNorm,
             output_dim: int = 512,
+            batch_first: bool = True,
     ):

         super().__init__(
Expand All @@ -831,6 +832,7 @@ def __init__(
             ls_init_value=ls_init_value,
             act_layer=act_layer,
             norm_layer=norm_layer,
+            batch_first=batch_first,
         )
         self.context_length = context_length
         self.cross_attn = nn.ModuleList([
Expand All @@ -842,6 +844,7 @@ def __init__(
                 act_layer=act_layer,
                 norm_layer=norm_layer,
                 is_cross_attention=True,
+                batch_first=batch_first,
             )
             for _ in range(layers)
         ])
@@ -878,9 +881,10 @@ def build_attention_mask(self):
         return mask

     def forward(self, image_embs, text_embs):
-        text_embs = text_embs.permute(1, 0, 2)  # NLD -> LND
-        image_embs = image_embs.permute(1, 0, 2)  # NLD -> LND
-        seq_len = text_embs.shape[0]
+        seq_len = text_embs.shape[1]
+        if not self.batch_first:
+            image_embs = image_embs.permute(1, 0, 2)  # NLD -> LND
+            text_embs = text_embs.permute(1, 0, 2)  # NLD -> LND

         for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
             if self.grad_checkpointing and not torch.jit.is_scripting():
Expand All @@ -891,13 +895,14 @@ def forward(self, image_embs, text_embs):
                 text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
                 text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)

-        x = text_embs.permute(1, 0, 2)  # LND -> NLD
-        x = self.ln_final(x)
+        if not self.batch_first:
+            text_embs = text_embs.permute(1, 0, 2)  # LND -> NLD

+        out = self.ln_final(text_embs)
         if self.text_projection is not None:
-            x = x @ self.text_projection
+            out = out @ self.text_projection

-        return x
+        return out

     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
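Since batch_first only changes tensor layout, the refactor should be numerically a no-op. A quick sanity-check sketch (not part of the commit) that the batch-first and seq-first layouts agree when given the same attention weights:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_head = 64, 8

attn_bf = nn.MultiheadAttention(d_model, n_head, batch_first=True)
attn_sf = nn.MultiheadAttention(d_model, n_head, batch_first=False)
attn_sf.load_state_dict(attn_bf.state_dict())  # same weights, different layout convention

x = torch.randn(2, 10, d_model)  # (N, L, D)

out_bf, _ = attn_bf(x, x, x, need_weights=False)
out_sf, _ = attn_sf(x.permute(1, 0, 2), x.permute(1, 0, 2), x.permute(1, 0, 2), need_weights=False)

# permuting the seq-first output back to (N, L, D) recovers the batch-first result
print(torch.allclose(out_bf, out_sf.permute(1, 0, 2), atol=1e-6))  # True
```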

0 comments on commit fbcabec
