diff --git a/src/open_clip/hf_model.py b/src/open_clip/hf_model.py
index fbccc8127..02c428aab 100644
--- a/src/open_clip/hf_model.py
+++ b/src/open_clip/hf_model.py
@@ -105,6 +105,7 @@ def __init__(
             raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
         if config is None:
             self.config = AutoConfig.from_pretrained(model_name_or_path)
+            self.context_length = self.config.max_length
             create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
                 AutoModel.from_config, self.config)
             # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
@@ -118,7 +119,7 @@ def __init__(
             self.transformer = AutoModel.from_config(config)
         if pooler_type is None:  # get default arch pooler
             pooler_type = (arch_dict[self.config.model_type]["pooler"])
-            
+
         self.pooler = _POOLERS[pooler_type]()
 
         d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
@@ -134,6 +135,14 @@ def __init__(
                 nn.Linear(hidden_size, output_dim, bias=False),
             )
 
+    def build_attention_mask(self):
+        # lazily create causal attention mask, with full attention between the tokens
+        # pytorch uses additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)  # zero out the lower diagonal
+        return mask
+
     def forward(self, x: TensorType):
         attn_mask = (x != self.config.pad_token_id).long()
         out = self.transformer(input_ids=x, attention_mask=attn_mask)
@@ -142,11 +151,11 @@ def forward(self, x: TensorType):
 
         seq_len = out.last_hidden_state.shape[1]
         tokens = (
-            out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :] 
-            if type(self.pooler) == ClsPooler 
+            out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
+            if type(self.pooler) == ClsPooler
            else out.last_hidden_state
         )
-        
+
         if self.output_tokens:
             return projected, tokens
         return projected
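
For context, a minimal standalone sketch (not part of the patch) of the additive causal mask that the new build_attention_mask builds, using a hypothetical context_length of 4 in place of self.context_length (which the patch sets from self.config.max_length):

import torch

context_length = 4  # stands in for self.context_length
mask = torch.empty(context_length, context_length)
mask.fill_(float("-inf"))  # start fully masked
mask.triu_(1)              # keep -inf strictly above the diagonal, zero elsewhere
# mask is now:
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
# Added to raw attention scores, this blocks each token from attending to later positions.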