diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index f9a18bc29..d514dbb20 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -122,6 +122,8 @@ def __init__( self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) self.pad_id = pad_id + self.context_length = multimodal_cfg.context_length + @torch.jit.ignore def set_grad_checkpointing(self, enable: bool = True): self.visual.set_grad_checkpointing(enable)