diff --git a/requirements-training.txt b/requirements-training.txt
index 52a5c6f80..8484fb40b 100644
--- a/requirements-training.txt
+++ b/requirements-training.txt
@@ -1,4 +1,4 @@
-torch>=1.9.0
+torch>=1.10.0
 torchvision
 webdataset>=0.2.5
 regex
diff --git a/requirements.txt b/requirements.txt
index 3ff4f5d0a..80903f650 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-torch>=1.9.0
+torch>=1.10.0
 torchvision
 regex
 ftfy
diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py
index 272b2cc06..eb5502b64 100644
--- a/src/open_clip/coca_model.py
+++ b/src/open_clip/coca_model.py
@@ -7,7 +7,6 @@
 from dataclasses import dataclass
 
 from .transformer import (
-    LayerNormFp32,
     LayerNorm,
     QuickGELU,
     MultimodalTransformer,
@@ -58,9 +57,7 @@ def _build_text_decoder_tower(
 ):
     multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
     act_layer = QuickGELU if quick_gelu else nn.GELU
-    norm_layer = (
-        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
-    )
+    norm_layer = LayerNorm
 
     decoder = MultimodalTransformer(
         context_length=multimodal_cfg.context_length,
diff --git a/src/open_clip/model.py b/src/open_clip/model.py
index 469d7f5a9..0376471ba 100644
--- a/src/open_clip/model.py
+++ b/src/open_clip/model.py
@@ -18,7 +18,7 @@
 from .hf_model import HFTextEncoder
 from .modified_resnet import ModifiedResNet
 from .timm_model import TimmModel
-from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\
+from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\
     text_global_pool
 from .utils import to_2tuple
 
@@ -139,7 +139,7 @@ def _build_vision_tower(
         )
     else:
         vision_heads = vision_cfg.width // vision_cfg.head_width
-        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+        norm_layer = LayerNorm
         if vision_cfg.norm_kwargs:
             norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
         if vision_cfg.act_kwargs is not None:
@@ -190,7 +190,7 @@ def _build_text_tower(
         )
     else:
         act_layer = QuickGELU if quick_gelu else nn.GELU
-        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+        norm_layer = LayerNorm
         if text_cfg.norm_kwargs:
             norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
         if text_cfg.act_kwargs is not None:
diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py
index 6d4e604d8..bc451059d 100644
--- a/src/open_clip/transformer.py
+++ b/src/open_clip/transformer.py
@@ -13,7 +13,10 @@
 
 
 class LayerNormFp32(nn.LayerNorm):
-    """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
+    """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).
+
+    Deprecated: pytorch 1.10+ always performs LayerNorm in fp32. Retained for checkpoint compatibility.
+    """
 
     def forward(self, x: torch.Tensor):
         orig_type = x.dtype
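
A minimal sketch (not part of the patch) for sanity-checking the rationale stated in the new LayerNormFp32 docstring, i.e. that PyTorch 1.10+ already performs the LayerNorm computation in fp32 for half-precision inputs, which is what makes the explicit LayerNormFp32 selection in the towers redundant. It assumes a CUDA device and uses illustrative shapes; it compares a plain half-precision nn.LayerNorm against an explicit upcast-to-fp32-and-back computation.

# sanity_check_layernorm_fp16.py -- illustrative only, assumes CUDA is available
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)

# Plain LayerNorm with fp16 parameters, as used after this change.
ln = nn.LayerNorm(512).cuda().half()
x = torch.randn(4, 77, 512, device="cuda", dtype=torch.float16)

# Native half-precision LayerNorm.
y_native = ln(x)

# Explicit upcast path (what LayerNormFp32 effectively did): normalize in fp32, cast back.
y_cast = F.layer_norm(
    x.float(), ln.normalized_shape, ln.weight.float(), ln.bias.float(), ln.eps
).to(x.dtype)

# If the native kernel accumulates in fp32, the two results should agree to ~fp16 precision.
print((y_native - y_cast).abs().max().item())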