diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py
index 0a30e9466..ce5e0d3f7 100644
--- a/src/open_clip/transformer.py
+++ b/src/open_clip/transformer.py
@@ -586,7 +586,7 @@ def build_attention_mask(self):
 
     def build_cls_mask(self, text, cast_dtype: torch.dtype):
         cls_mask = (text != self.pad_id).unsqueeze(1)
-        cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
+        cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
         additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
         additive_mask.fill_(0)
         additive_mask.masked_fill_(~cls_mask, float("-inf"))
diff --git a/tests/test_hf_model.py b/tests/test_hf_model.py
index 79df2f2cf..f9191f1f4 100644
--- a/tests/test_hf_model.py
+++ b/tests/test_hf_model.py
@@ -8,8 +8,8 @@ def test_poolers():
     bs, sl, d = 2, 10, 5
     h = torch.arange(sl).repeat(bs).reshape(bs, sl)[..., None] * torch.linspace(0.2, 1., d)
-    mask = torch.ones(bs, sl, dtype=torch.long)
-    mask[:2, 6:] = 0
+    mask = torch.ones(bs, sl, dtype=torch.bool)
+    mask[:2, 6:] = False
     x = BaseModelOutput(h)
     for name, cls in _POOLERS.items():
         pooler = cls()
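
Context for the fix: `cls_mask` is produced by a `!=` comparison, so it is a bool tensor, and the downstream `~` and `masked_fill_` operations expect a bool mask; filling the CLS padding with `True` rather than `1.0` keeps the fill value in the mask's own dtype, and the test mask is built as `torch.bool` for the same reason. Below is a minimal sketch of the masking path, using a made-up `text` tensor and `pad_id` (both hypothetical, for illustration only):

    import torch
    import torch.nn.functional as F

    text = torch.tensor([[5, 7, 0, 0]])  # hypothetical token ids; 0 acts as pad_id
    pad_id = 0

    # Bool key-padding mask: True marks real tokens, False marks padding.
    cls_mask = (text != pad_id).unsqueeze(1)  # shape (1, 1, 4), dtype bool

    # Grow the mask to cover the prepended CLS token; the fill value is the
    # bool True, matching the mask dtype as in the patched line above.
    cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)  # (1, 5, 5)

    # Convert to an additive attention mask: 0 where attention is allowed,
    # -inf where it is blocked. `~` and masked_fill_ require a bool mask.
    additive_mask = torch.zeros(cls_mask.shape, dtype=torch.float32)
    additive_mask.masked_fill_(~cls_mask, float("-inf"))
    print(additive_mask)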