From 60ded6260059dd6d8e06ae44ecb7d7b612e5235f Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 24 Nov 2023 12:10:40 +0200 Subject: [PATCH] Remove more Torch version comparisons Follows up on ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3 --- main.py | 14 +++-- sgm/models/autoencoder.py | 4 +- sgm/modules/attention.py | 66 +++++++++--------------- sgm/modules/autoencoding/temporal_ae.py | 6 +-- sgm/modules/diffusionmodules/model.py | 7 +-- sgm/modules/diffusionmodules/wrappers.py | 12 ++--- 6 files changed, 40 insertions(+), 69 deletions(-) diff --git a/main.py b/main.py index 5e03c1c5..3178a740 100644 --- a/main.py +++ b/main.py @@ -187,13 +187,12 @@ def str2bool(v): default=False, # TODO: later default to True help="log to wandb", ) - if version.parse(torch.__version__) >= version.parse("2.0.0"): - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help="single checkpoint file to resume from", - ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="single checkpoint file to resume from", + ) default_args = default_trainer_args() for key in default_args: parser.add_argument("--" + key, default=default_args[key]) @@ -618,7 +617,6 @@ def init_wandb(save_dir, opt, config, group_name, name_str): # move before model init, in case a torch.compile(...) is called somewhere if opt.enable_tf32: - # pt_version = version.parse(torch.__version__) torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True print(f"Enabling TF32 for PyTorch {torch.__version__}") diff --git a/sgm/models/autoencoder.py b/sgm/models/autoencoder.py index 2949b910..6cc86c0d 100644 --- a/sgm/models/autoencoder.py +++ b/sgm/models/autoencoder.py @@ -9,7 +9,6 @@ import torch import torch.nn as nn from einops import rearrange -from packaging import version from ..modules.autoencoding.regularizers import AbstractRegularizer from ..modules.ema import LitEma @@ -43,8 +42,7 @@ def __init__( self.model_ema = LitEma(self, decay=ema_decay) logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - if version.parse(torch.__version__) >= version.parse("2.0.0"): - self.automatic_optimization = False + self.automatic_optimization = False # pytorch lightning def apply_ckpt(self, ckpt: Union[None, str, dict]): if ckpt is None: diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py index 52a50b7b..8d93e188 100644 --- a/sgm/modules/attention.py +++ b/sgm/modules/attention.py @@ -9,42 +9,31 @@ from packaging import version from torch import nn from torch.utils.checkpoint import checkpoint +from torch.backends.cuda import SDPBackend, sdp_kernel logpy = logging.getLogger(__name__) -if version.parse(torch.__version__) >= version.parse("2.0.0"): - SDP_IS_AVAILABLE = True - from torch.backends.cuda import SDPBackend, sdp_kernel - - BACKEND_MAP = { - SDPBackend.MATH: { - "enable_math": True, - "enable_flash": False, - "enable_mem_efficient": False, - }, - SDPBackend.FLASH_ATTENTION: { - "enable_math": False, - "enable_flash": True, - "enable_mem_efficient": False, - }, - SDPBackend.EFFICIENT_ATTENTION: { - "enable_math": False, - "enable_flash": False, - "enable_mem_efficient": True, - }, - None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True}, - } -else: - from contextlib import nullcontext - - SDP_IS_AVAILABLE = False - sdp_kernel = nullcontext - BACKEND_MAP = {} - logpy.warn( - f"No SDP backend available, likely because you are running in pytorch " - f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. " - f"You might want to consider upgrading." - ) +SDP_IS_AVAILABLE = True + + +BACKEND_MAP = { + SDPBackend.MATH: { + "enable_math": True, + "enable_flash": False, + "enable_mem_efficient": False, + }, + SDPBackend.FLASH_ATTENTION: { + "enable_math": False, + "enable_flash": True, + "enable_mem_efficient": False, + }, + SDPBackend.EFFICIENT_ATTENTION: { + "enable_math": False, + "enable_flash": False, + "enable_mem_efficient": True, + }, + None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True}, +} try: import xformers @@ -476,10 +465,8 @@ def __init__( assert attn_mode in self.ATTENTION_MODES if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE: logpy.warn( - f"Attention mode '{attn_mode}' is not available. Falling " - f"back to native attention. This is not a problem in " - f"Pytorch >= 2.0. FYI, you are running with PyTorch " - f"version {torch.__version__}." + f"Attention mode '{attn_mode}' is not available. " + f"Falling back to native attention." ) attn_mode = "softmax" elif attn_mode == "softmax" and not SDP_IS_AVAILABLE: @@ -495,10 +482,7 @@ def __init__( logpy.info("Falling back to xformers efficient attention.") attn_mode = "softmax-xformers" attn_cls = self.ATTENTION_MODES[attn_mode] - if version.parse(torch.__version__) >= version.parse("2.0.0"): - assert sdp_backend is None or isinstance(sdp_backend, SDPBackend) - else: - assert sdp_backend is None + assert sdp_backend is None or isinstance(sdp_backend, SDPBackend) self.disable_self_attn = disable_self_attn self.attn1 = attn_cls( query_dim=dim, diff --git a/sgm/modules/autoencoding/temporal_ae.py b/sgm/modules/autoencoding/temporal_ae.py index 4a17a911..de83ff0d 100644 --- a/sgm/modules/autoencoding/temporal_ae.py +++ b/sgm/modules/autoencoding/temporal_ae.py @@ -1,3 +1,4 @@ +import warnings from typing import Callable, Iterable, Union import torch @@ -260,10 +261,7 @@ def make_time_attn( f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels" ) if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers": - print( - f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. " - f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}" - ) + warnings.warn(f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention.") attn_type = "vanilla" if attn_type == "vanilla": diff --git a/sgm/modules/diffusionmodules/model.py b/sgm/modules/diffusionmodules/model.py index 4cf9d921..b33748b9 100644 --- a/sgm/modules/diffusionmodules/model.py +++ b/sgm/modules/diffusionmodules/model.py @@ -282,12 +282,9 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): "linear", "none", ], f"attn_type {attn_type} unknown" - if ( - version.parse(torch.__version__) < version.parse("2.0.0") - and attn_type != "none" - ): + if attn_type != "none": assert XFORMERS_IS_AVAILABLE, ( - f"We do not support vanilla attention in {torch.__version__} anymore, " + f"We do not support vanilla attention anymore, " f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'" ) attn_type = "vanilla-xformers" diff --git a/sgm/modules/diffusionmodules/wrappers.py b/sgm/modules/diffusionmodules/wrappers.py index 23c7d073..56a5d099 100644 --- a/sgm/modules/diffusionmodules/wrappers.py +++ b/sgm/modules/diffusionmodules/wrappers.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -from packaging import version OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" @@ -8,13 +7,10 @@ class IdentityWrapper(nn.Module): def __init__(self, diffusion_model, compile_model: bool = False): super().__init__() - compile = ( - torch.compile - if (version.parse(torch.__version__) >= version.parse("2.0.0")) - and compile_model - else lambda x: x - ) - self.diffusion_model = compile(diffusion_model) + if compile_model: + self.diffusion_model = torch.compile(diffusion_model) + else: + self.diffusion_model = diffusion_model def forward(self, *args, **kwargs): return self.diffusion_model(*args, **kwargs)