Remove more Torch version comparisons #183

Open · wants to merge 1 commit into main
14 changes: 6 additions & 8 deletions main.py
@@ -187,13 +187,12 @@ def str2bool(v):
         default=False,  # TODO: later default to True
         help="log to wandb",
     )
-    if version.parse(torch.__version__) >= version.parse("2.0.0"):
-        parser.add_argument(
-            "--resume_from_checkpoint",
-            type=str,
-            default=None,
-            help="single checkpoint file to resume from",
-        )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help="single checkpoint file to resume from",
+    )
     default_args = default_trainer_args()
     for key in default_args:
         parser.add_argument("--" + key, default=default_args[key])
@@ -618,7 +617,6 @@ def init_wandb(save_dir, opt, config, group_name, name_str):

     # move before model init, in case a torch.compile(...) is called somewhere
     if opt.enable_tf32:
-        # pt_version = version.parse(torch.__version__)
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
         print(f"Enabling TF32 for PyTorch {torch.__version__}")
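For reference, the two TF32 flags kept above are standard PyTorch switches; they only take effect on NVIDIA Ampere or newer GPUs and are harmless no-ops elsewhere. A minimal sketch of what --enable_tf32 turns on:

    import torch

    # Allow TF32 tensor-core math: faster matmuls/convolutions at slightly reduced precision.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True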
4 changes: 1 addition & 3 deletions sgm/models/autoencoder.py
@@ -9,7 +9,6 @@
 import torch
 import torch.nn as nn
 from einops import rearrange
-from packaging import version

 from ..modules.autoencoding.regularizers import AbstractRegularizer
 from ..modules.ema import LitEma
@@ -43,8 +42,7 @@ def __init__(
             self.model_ema = LitEma(self, decay=ema_decay)
             logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

-        if version.parse(torch.__version__) >= version.parse("2.0.0"):
-            self.automatic_optimization = False
+        self.automatic_optimization = False  # pytorch lightning

     def apply_ckpt(self, ckpt: Union[None, str, dict]):
         if ckpt is None:
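With automatic_optimization now unconditionally False, PyTorch Lightning expects the subclass's training_step to drive optimization itself. A generic manual-optimization sketch (compute_loss is a hypothetical stand-in, not this repo's method):

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()          # optimizer(s) from configure_optimizers
        loss = self.compute_loss(batch)  # hypothetical loss hook
        opt.zero_grad()
        self.manual_backward(loss)       # Lightning's backward for manual mode
        opt.step()
        return loss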
66 changes: 25 additions & 41 deletions sgm/modules/attention.py
@@ -9,42 +9,31 @@
 from packaging import version
 from torch import nn
 from torch.utils.checkpoint import checkpoint
+from torch.backends.cuda import SDPBackend, sdp_kernel

 logpy = logging.getLogger(__name__)

-if version.parse(torch.__version__) >= version.parse("2.0.0"):
-    SDP_IS_AVAILABLE = True
-    from torch.backends.cuda import SDPBackend, sdp_kernel
-
-    BACKEND_MAP = {
-        SDPBackend.MATH: {
-            "enable_math": True,
-            "enable_flash": False,
-            "enable_mem_efficient": False,
-        },
-        SDPBackend.FLASH_ATTENTION: {
-            "enable_math": False,
-            "enable_flash": True,
-            "enable_mem_efficient": False,
-        },
-        SDPBackend.EFFICIENT_ATTENTION: {
-            "enable_math": False,
-            "enable_flash": False,
-            "enable_mem_efficient": True,
-        },
-        None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
-    }
-else:
-    from contextlib import nullcontext
-
-    SDP_IS_AVAILABLE = False
-    sdp_kernel = nullcontext
-    BACKEND_MAP = {}
-    logpy.warn(
-        f"No SDP backend available, likely because you are running in pytorch "
-        f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
-        f"You might want to consider upgrading."
-    )
+SDP_IS_AVAILABLE = True
+
+
+BACKEND_MAP = {
+    SDPBackend.MATH: {
+        "enable_math": True,
+        "enable_flash": False,
+        "enable_mem_efficient": False,
+    },
+    SDPBackend.FLASH_ATTENTION: {
+        "enable_math": False,
+        "enable_flash": True,
+        "enable_mem_efficient": False,
+    },
+    SDPBackend.EFFICIENT_ATTENTION: {
+        "enable_math": False,
+        "enable_flash": False,
+        "enable_mem_efficient": True,
+    },
+    None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
+}

 try:
     import xformers
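For context, each BACKEND_MAP entry is unpacked as keyword arguments into torch.backends.cuda.sdp_kernel, a context manager that pins scaled_dot_product_attention to a single backend. A minimal sketch (shapes are illustrative; assumes a CUDA device, since the flash backend needs fp16/bf16 on GPU):

    import torch
    import torch.nn.functional as F
    from torch.backends.cuda import sdp_kernel

    q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    with sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
        out = F.scaled_dot_product_attention(q, k, v)  # forced onto the flash backend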
@@ -476,10 +465,8 @@ def __init__(
         assert attn_mode in self.ATTENTION_MODES
         if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
             logpy.warn(
-                f"Attention mode '{attn_mode}' is not available. Falling "
-                f"back to native attention. This is not a problem in "
-                f"Pytorch >= 2.0. FYI, you are running with PyTorch "
-                f"version {torch.__version__}."
+                f"Attention mode '{attn_mode}' is not available. "
+                f"Falling back to native attention."
             )
             attn_mode = "softmax"
         elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
@@ -495,10 +482,7 @@ def __init__(
logpy.info("Falling back to xformers efficient attention.")
attn_mode = "softmax-xformers"
attn_cls = self.ATTENTION_MODES[attn_mode]
if version.parse(torch.__version__) >= version.parse("2.0.0"):
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
else:
assert sdp_backend is None
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
6 changes: 2 additions & 4 deletions sgm/modules/autoencoding/temporal_ae.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Callable, Iterable, Union

 import torch
@@ -260,10 +261,7 @@ def make_time_attn(
f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
)
if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
print(
f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
)
warnings.warn(f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention.")
attn_type = "vanilla"

if attn_type == "vanilla":
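Switching from print to warnings.warn also lets callers manage the fallback notice through the standard warnings machinery, e.g.:

    import warnings

    # Silence the fallback notice (the message pattern is a regex prefix match).
    warnings.filterwarnings("ignore", message="Attention mode .*")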
7 changes: 2 additions & 5 deletions sgm/modules/diffusionmodules/model.py
@@ -282,12 +282,9 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
"linear",
"none",
], f"attn_type {attn_type} unknown"
if (
version.parse(torch.__version__) < version.parse("2.0.0")
and attn_type != "none"
):
if attn_type != "none":
assert XFORMERS_IS_AVAILABLE, (
f"We do not support vanilla attention in {torch.__version__} anymore, "
f"We do not support vanilla attention anymore, "
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
)
attn_type = "vanilla-xformers"
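Note that with the version gate removed, every attn_type other than "none" now requires xformers on all PyTorch versions; e.g. a call like the following (arguments illustrative) asserts unless xformers is installed, and is then upgraded to "vanilla-xformers":

    attn = make_attn(in_channels=512, attn_type="vanilla")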
12 changes: 4 additions & 8 deletions sgm/modules/diffusionmodules/wrappers.py
@@ -1,20 +1,16 @@
 import torch
 import torch.nn as nn
-from packaging import version

 OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"


 class IdentityWrapper(nn.Module):
     def __init__(self, diffusion_model, compile_model: bool = False):
         super().__init__()
-        compile = (
-            torch.compile
-            if (version.parse(torch.__version__) >= version.parse("2.0.0"))
-            and compile_model
-            else lambda x: x
-        )
-        self.diffusion_model = compile(diffusion_model)
+        if compile_model:
+            self.diffusion_model = torch.compile(diffusion_model)
+        else:
+            self.diffusion_model = diffusion_model

     def forward(self, *args, **kwargs):
         return self.diffusion_model(*args, **kwargs)
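A quick usage sketch of the simplified wrapper (the wrapped module is a stand-in, not a real diffusion model):

    import torch
    import torch.nn as nn

    wrapper = IdentityWrapper(nn.Linear(8, 8), compile_model=True)  # plain eager module if False
    out = wrapper(torch.randn(2, 8))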