From a3539a374a4f21753afb9b1f8dae285f4748d12e Mon Sep 17 00:00:00 2001
From: AmirMasoud Nourollah
Date: Fri, 5 Aug 2022 21:03:47 +0430
Subject: [PATCH] Change some code to a more Pythonic style

Apply PEP 8 spacing to keyword and default arguments, separate
top-level definitions with two blank lines, replace mutable default
arguments with None in several constructors, collapse simple if/else
returns into conditional expressions, use 'yield from' in cycle(),
make Conv2DMod._get_same_padding a staticmethod, and wrap overlong
lines.
---
 stylegan2_pytorch/cli.py               | 195 ++++++-----
 stylegan2_pytorch/diff_augment.py      |  26 +-
 stylegan2_pytorch/stylegan2_pytorch.py | 454 ++++++++++++++++---------
 3 files changed, 406 insertions(+), 269 deletions(-)

diff --git a/stylegan2_pytorch/cli.py b/stylegan2_pytorch/cli.py
index 03f9863..063caad 100644
--- a/stylegan2_pytorch/cli.py
+++ b/stylegan2_pytorch/cli.py
@@ -13,14 +13,17 @@
 
 import numpy as np
 
+
 def cast_list(el):
     return el if isinstance(el, list) else [el]
 
-def timestamped_filename(prefix = 'generated-'):
+
+def timestamped_filename(prefix='generated-'):
     now = datetime.now()
     timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
     return f'{prefix}{timestamp}'
 
+
 def set_seed(seed):
     torch.manual_seed(seed)
     torch.backends.cudnn.deterministic = True
@@ -28,6 +31,7 @@ def set_seed(seed):
     np.random.seed(seed)
     random.seed(seed)
 
+
 def run_training(rank, world_size, model_args, data, load_from, new, num_train_steps, name, seed):
     is_main = rank == 0
     is_ddp = world_size > 1
@@ -41,9 +45,9 @@ def run_training(rank, world_size, model_args, data, load_from, new, num_train_s
         print(f"{rank + 1}/{world_size} process initialized.")
 
     model_args.update(
-        is_ddp = is_ddp,
-        rank = rank,
-        world_size = world_size
+        is_ddp=is_ddp,
+        rank=rank,
+        world_size=world_size
     )
 
     model = Trainer(**model_args)
@@ -55,7 +59,7 @@ def run_training(rank, world_size, model_args, data, load_from, new, num_train_s
 
     model.set_data_src(data)
 
-    progress_bar = tqdm(initial = model.steps, total = num_train_steps, mininterval=10., desc=f'{name}<{data}>')
+    progress_bar = tqdm(initial=model.steps, total=num_train_steps, mininterval=10., desc=f'{name}<{data}>')
     while model.steps < num_train_steps:
         retry_call(model.train, tries=3, exceptions=NanException)
         progress_bar.n = model.steps
@@ -68,94 +72,95 @@ def run_training(rank, world_size, model_args, data, load_from, new, num_train_s
     if is_ddp:
         dist.destroy_process_group()
 
+
 def train_from_folder(
-    data = './data',
-    results_dir = './results',
-    models_dir = './models',
-    name = 'default',
-    new = False,
-    load_from = -1,
-    image_size = 128,
-    network_capacity = 16,
-    fmap_max = 512,
-    transparent = False,
-    batch_size = 5,
-    gradient_accumulate_every = 6,
-    num_train_steps = 150000,
-    learning_rate = 2e-4,
-    lr_mlp = 0.1,
-    ttur_mult = 1.5,
-    rel_disc_loss = False,
-    num_workers = None,
-    save_every = 1000,
-    evaluate_every = 1000,
-    generate = False,
-    num_generate = 1,
-    generate_interpolation = False,
-    interpolation_num_steps = 100,
-    save_frames = False,
-    num_image_tiles = 8,
-    trunc_psi = 0.75,
-    mixed_prob = 0.9,
-    fp16 = False,
-    no_pl_reg = False,
-    cl_reg = False,
-    fq_layers = [],
-    fq_dict_size = 256,
-    attn_layers = [],
-    no_const = False,
-    aug_prob = 0.,
-    aug_types = ['translation', 'cutout'],
-    top_k_training = False,
-    generator_top_k_gamma = 0.99,
-    generator_top_k_frac = 0.5,
-    dual_contrast_loss = False,
-    dataset_aug_prob = 0.,
-    multi_gpus = False,
-    calculate_fid_every = None,
-    calculate_fid_num_images = 12800,
-    clear_fid_cache = False,
-    seed = 42,
-    log = False
+    data='./data',
+    results_dir='./results',
+    models_dir='./models',
+    name='default',
+    new=False,
+    load_from=-1,
+    image_size=128,
+    network_capacity=16,
+    fmap_max=512,
+    transparent=False,
+    batch_size=5,
+    gradient_accumulate_every=6,
+    num_train_steps=150000,
+    learning_rate=2e-4,
+    lr_mlp=0.1,
+    ttur_mult=1.5,
+    rel_disc_loss=False,
+    num_workers=None,
+    save_every=1000,
+    evaluate_every=1000,
+    generate=False,
+    num_generate=1,
+    generate_interpolation=False,
+    interpolation_num_steps=100,
+    save_frames=False,
+    num_image_tiles=8,
+    trunc_psi=0.75,
+    mixed_prob=0.9,
+    fp16=False,
+    no_pl_reg=False,
+    cl_reg=False,
+    fq_layers=[],
+    fq_dict_size=256,
+    attn_layers=[],
+    no_const=False,
+    aug_prob=0.,
+    aug_types=['translation', 'cutout'],
+    top_k_training=False,
+    generator_top_k_gamma=0.99,
+    generator_top_k_frac=0.5,
+    dual_contrast_loss=False,
+    dataset_aug_prob=0.,
+    multi_gpus=False,
+    calculate_fid_every=None,
+    calculate_fid_num_images=12800,
+    clear_fid_cache=False,
+    seed=42,
+    log=False
 ):
     model_args = dict(
-        name = name,
-        results_dir = results_dir,
-        models_dir = models_dir,
-        batch_size = batch_size,
-        gradient_accumulate_every = gradient_accumulate_every,
-        image_size = image_size,
-        network_capacity = network_capacity,
-        fmap_max = fmap_max,
-        transparent = transparent,
-        lr = learning_rate,
-        lr_mlp = lr_mlp,
-        ttur_mult = ttur_mult,
-        rel_disc_loss = rel_disc_loss,
-        num_workers = num_workers,
-        save_every = save_every,
-        evaluate_every = evaluate_every,
-        num_image_tiles = num_image_tiles,
-        trunc_psi = trunc_psi,
-        fp16 = fp16,
-        no_pl_reg = no_pl_reg,
-        cl_reg = cl_reg,
-        fq_layers = fq_layers,
-        fq_dict_size = fq_dict_size,
-        attn_layers = attn_layers,
-        no_const = no_const,
-        aug_prob = aug_prob,
-        aug_types = cast_list(aug_types),
-        top_k_training = top_k_training,
-        generator_top_k_gamma = generator_top_k_gamma,
-        generator_top_k_frac = generator_top_k_frac,
-        dual_contrast_loss = dual_contrast_loss,
-        dataset_aug_prob = dataset_aug_prob,
-        calculate_fid_every = calculate_fid_every,
-        calculate_fid_num_images = calculate_fid_num_images,
-        clear_fid_cache = clear_fid_cache,
-        mixed_prob = mixed_prob,
-        log = log
+        name=name,
+        results_dir=results_dir,
+        models_dir=models_dir,
+        batch_size=batch_size,
+        gradient_accumulate_every=gradient_accumulate_every,
+        image_size=image_size,
+        network_capacity=network_capacity,
+        fmap_max=fmap_max,
+        transparent=transparent,
+        lr=learning_rate,
+        lr_mlp=lr_mlp,
+        ttur_mult=ttur_mult,
+        rel_disc_loss=rel_disc_loss,
+        num_workers=num_workers,
+        save_every=save_every,
+        evaluate_every=evaluate_every,
+        num_image_tiles=num_image_tiles,
+        trunc_psi=trunc_psi,
+        fp16=fp16,
+        no_pl_reg=no_pl_reg,
+        cl_reg=cl_reg,
+        fq_layers=fq_layers,
+        fq_dict_size=fq_dict_size,
+        attn_layers=attn_layers,
+        no_const=no_const,
+        aug_prob=aug_prob,
+        aug_types=cast_list(aug_types),
+        top_k_training=top_k_training,
+        generator_top_k_gamma=generator_top_k_gamma,
+        generator_top_k_frac=generator_top_k_frac,
+        dual_contrast_loss=dual_contrast_loss,
+        dataset_aug_prob=dataset_aug_prob,
+        calculate_fid_every=calculate_fid_every,
+        calculate_fid_num_images=calculate_fid_num_images,
+        clear_fid_cache=clear_fid_cache,
+        mixed_prob=mixed_prob,
+        log=log
     )
 
     if generate:
@@ -171,7 +176,8 @@ def train_from_folder(
         model = Trainer(**model_args)
         model.load(load_from)
         samples_name = timestamped_filename()
-        model.generate_interpolation(samples_name, num_image_tiles, num_steps = interpolation_num_steps, save_frames = save_frames)
+        model.generate_interpolation(samples_name, num_image_tiles, num_steps=interpolation_num_steps,
+                                     save_frames=save_frames)
         print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
         return
 
@@ -182,9 +188,10 @@ def train_from_folder(
         return
 
     mp.spawn(run_training,
-        args=(world_size, model_args, data, load_from, new, num_train_steps, name, seed),
-        nprocs=world_size,
-        join=True)
+             args=(world_size, model_args, data, load_from, new, num_train_steps, name, seed),
+             nprocs=world_size,
+             join=True)
+
 
 def main():
     fire.Fire(train_from_folder)
diff --git a/stylegan2_pytorch/diff_augment.py b/stylegan2_pytorch/diff_augment.py
index 53efa54..ff6cf1f 100644
--- a/stylegan2_pytorch/diff_augment.py
+++ b/stylegan2_pytorch/diff_augment.py
@@ -24,16 +24,21 @@ def rand_brightness(x, scale):
     x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) * scale
     return x
 
+
 def rand_saturation(x, scale):
     x_mean = x.mean(dim=1, keepdim=True)
-    x = (x - x_mean) * (((torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) * 2.0 * scale) + 1.0) + x_mean
+    x = (x - x_mean) * (
+            ((torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) * 2.0 * scale) + 1.0) + x_mean
     return x
 
+
 def rand_contrast(x, scale):
     x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
-    x = (x - x_mean) * (((torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) * 2.0 * scale) + 1.0) + x_mean
+    x = (x - x_mean) * (
+            ((torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) * 2.0 * scale) + 1.0) + x_mean
     return x
 
+
 def rand_translation(x, ratio=0.125):
     shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
     translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
@@ -49,11 +54,12 @@ def rand_translation(x, ratio=0.125):
     x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
     return x
 
+
 def rand_offset(x, ratio=1, ratio_h=1, ratio_v=1):
     w, h = x.size(2), x.size(3)
 
     imgs = []
-    for img in x.unbind(dim = 0):
+    for img in x.unbind(dim=0):
         max_h = int(w * ratio * ratio_h)
         max_v = int(h * ratio * ratio_v)
 
@@ -70,12 +76,15 @@ def rand_offset(x, ratio=1, ratio_h=1, ratio_v=1):
 
     return torch.stack(imgs)
 
+
 def rand_offset_h(x, ratio=1):
     return rand_offset(x, ratio=1, ratio_h=ratio, ratio_v=0)
 
+
 def rand_offset_v(x, ratio=1):
     return rand_offset(x, ratio=1, ratio_h=0, ratio_v=ratio)
 
+
 def rand_cutout(x, ratio=0.5):
     cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
     offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
@@ -92,15 +101,18 @@ def rand_cutout(x, ratio=0.5):
     x = x * mask.unsqueeze(1)
     return x
 
+
 AUGMENT_FNS = {
     'brightness': [partial(rand_brightness, scale=1.)],
     'lightbrightness': [partial(rand_brightness, scale=.65)],
-    'contrast':  [partial(rand_contrast, scale=.5)],
-    'lightcontrast':  [partial(rand_contrast, scale=.25)],
+    'contrast': [partial(rand_contrast, scale=.5)],
+    'lightcontrast': [partial(rand_contrast, scale=.25)],
     'saturation': [partial(rand_saturation, scale=1.)],
     'lightsaturation': [partial(rand_saturation, scale=.5)],
-    'color': [partial(rand_brightness, scale=1.), partial(rand_saturation, scale=1.), partial(rand_contrast, scale=0.5)],
-    'lightcolor': [partial(rand_brightness, scale=0.65), partial(rand_saturation, scale=.5), partial(rand_contrast, scale=0.5)],
+    'color': [partial(rand_brightness, scale=1.), partial(rand_saturation, scale=1.),
+              partial(rand_contrast, scale=0.5)],
+    'lightcolor': [partial(rand_brightness, scale=0.65), partial(rand_saturation, scale=.5),
+                   partial(rand_contrast, scale=0.5)],
     'offset': [rand_offset],
     'offset_h': [rand_offset_h],
     'offset_v': [rand_offset_v],
diff --git a/stylegan2_pytorch/stylegan2_pytorch.py b/stylegan2_pytorch/stylegan2_pytorch.py
index d3f50dc..65bc631 100644
--- a/stylegan2_pytorch/stylegan2_pytorch.py
+++ b/stylegan2_pytorch/stylegan2_pytorch.py
@@ -38,6 +38,7 @@
 
 try:
     from apex import amp
+
     APEX_AVAILABLE = True
 except:
     APEX_AVAILABLE = False
@@ -46,59 +47,66 @@
 
 assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
 
-
 # constants
 
 NUM_CORES = multiprocessing.cpu_count()
 EXTS = ['jpg', 'jpeg', 'png']
 
+
 # helper classes
 
 class NanException(Exception):
     pass
 
-class EMA():
+
+class EMA:
     def __init__(self, beta):
         super().__init__()
         self.beta = beta
+
     def update_average(self, old, new):
-        if not exists(old):
-            return new
-        return old * self.beta + (1 - self.beta) * new
+        return new if not exists(old) else old * self.beta + (1 - self.beta) * new
+
 
 class Flatten(nn.Module):
     def forward(self, x):
         return x.reshape(x.shape[0], -1)
 
+
 class RandomApply(nn.Module):
-    def __init__(self, prob, fn, fn_else = lambda x: x):
+    def __init__(self, prob, fn, fn_else=lambda x: x):
         super().__init__()
         self.fn = fn
         self.fn_else = fn_else
         self.prob = prob
+
     def forward(self, x):
         fn = self.fn if random() < self.prob else self.fn_else
         return fn(x)
 
+
 class Residual(nn.Module):
     def __init__(self, fn):
         super().__init__()
         self.fn = fn
+
     def forward(self, x):
         return self.fn(x) + x
 
+
 class ChanNorm(nn.Module):
-    def __init__(self, dim, eps = 1e-5):
+    def __init__(self, dim, eps=1e-5):
         super().__init__()
         self.eps = eps
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
         self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
 
     def forward(self, x):
-        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = 1, keepdim = True)
+        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
+        mean = torch.mean(x, dim=1, keepdim=True)
         return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
 
+
 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
@@ -108,67 +116,76 @@ def __init__(self, dim, fn):
 
     def forward(self, x):
         return self.fn(self.norm(x))
 
+
 class PermuteToFrom(nn.Module):
     def __init__(self, fn):
         super().__init__()
         self.fn = fn
+
     def forward(self, x):
         x = x.permute(0, 2, 3, 1)
         out, *_, loss = self.fn(x)
         out = out.permute(0, 3, 1, 2)
         return out, loss
 
+
 class Blur(nn.Module):
     def __init__(self):
         super().__init__()
         f = torch.Tensor([1, 2, 1])
         self.register_buffer('f', f)
+
     def forward(self, x):
         f = self.f
-        f = f[None, None, :] * f [None, :, None]
+        f = f[None, None, :] * f[None, :, None]
         return filter2d(x, f, normalized=True)
 
+
 # attention
 
 class DepthWiseConv2d(nn.Module):
-    def __init__(self, dim_in, dim_out, kernel_size, padding = 0, stride = 1, bias = True):
+    def __init__(self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
-            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
+            nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding, groups=dim_in, stride=stride,
+                      bias=bias),
+            nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias)
         )
+
     def forward(self, x):
         return self.net(x)
 
+
 class LinearAttention(nn.Module):
-    def __init__(self, dim, dim_head = 64, heads = 8):
+    def __init__(self, dim, dim_head=64, heads=8):
         super().__init__()
         self.scale = dim_head ** -0.5
         self.heads = heads
         inner_dim = dim_head * heads
 
         self.nonlin = nn.GELU()
-        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
-        self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = 1, bias = False)
+        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
+        self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding=1, bias=False)
         self.to_out = nn.Conv2d(inner_dim, dim, 1)
 
     def forward(self, fmap):
         h, x, y = self.heads, *fmap.shape[-2:]
-        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
-        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))
+        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim=1))
+        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h=h), (q, k, v))
 
-        q = q.softmax(dim = -1)
-        k = k.softmax(dim = -2)
+        q = q.softmax(dim=-1)
+        k = k.softmax(dim=-2)
 
         q = q * self.scale
 
         context = einsum('b n d, b n e -> b d e', k, v)
 
         out = einsum('b n d, b d e -> b n e', q, context)
-        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
+        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h=h, x=x, y=y)
 
         out = self.nonlin(out)
         return self.to_out(out)
 
+
 # one layer of self-attention and feedforward, for images
 
 attn_and_ff = lambda chan: nn.Sequential(*[
@@ -176,48 +193,55 @@ def forward(self, fmap):
     Residual(PreNorm(chan, nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1))))
 ])
 
+
 # helpers
 
 def exists(val):
     return val is not None
 
+
 @contextmanager
 def null_context():
     yield
 
+
 def combine_contexts(contexts):
     @contextmanager
     def multi_contexts():
         with ExitStack() as stack:
             yield [stack.enter_context(ctx()) for ctx in contexts]
+
     return multi_contexts
 
+
 def default(value, d):
     return value if exists(value) else d
 
+
 def cycle(iterable):
     while True:
-        for i in iterable:
-            yield i
+        yield from iterable
+
 
 def cast_list(el):
     return el if isinstance(el, list) else [el]
 
+
 def is_empty(t):
-    if isinstance(t, torch.Tensor):
-        return t.nelement() == 0
-    return not exists(t)
+    return t.nelement() == 0 if isinstance(t, torch.Tensor) else not exists(t)
+
 
 def raise_if_nan(t):
     if torch.isnan(t):
         raise NanException
 
+
 def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
     if is_ddp:
         num_no_syncs = gradient_accumulate_every - 1
         head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs
         tail = [null_context]
-        contexts =  head + tail
+        contexts = head + tail
     else:
         contexts = [null_context] * gradient_accumulate_every
 
@@ -225,6 +249,7 @@ def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps):
         with context():
             yield
 
+
 def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
     if fp16:
         with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss:
@@ -232,7 +257,8 @@ def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
     else:
         loss.backward(**kwargs)
 
-def gradient_penalty(images, output, weight = 10):
+
+def gradient_penalty(images, output, weight=10):
     batch_size = images.shape[0]
     gradients = torch_grad(outputs=output, inputs=images,
                            grad_outputs=torch.ones(output.size(), device=images.device),
@@ -241,6 +267,7 @@ def gradient_penalty(images, output, weight = 10):
     gradients = gradients.reshape(batch_size, -1)
     return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
 
+
 def calc_pl_lengths(styles, images):
     device = images.device
     num_pixels = images.shape[2] * images.shape[3]
@@ -253,25 +280,32 @@ def calc_pl_lengths(styles, images):
 
     return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()
 
+
 def noise(n, latent_dim, device):
     return torch.randn(n, latent_dim).cuda(device)
 
+
 def noise_list(n, layers, latent_dim, device):
     return [(noise(n, latent_dim, device), layers)]
 
+
 def mixed_list(n, layers, latent_dim, device):
     tt = int(torch.rand(()).numpy() * layers)
     return noise_list(n, tt, latent_dim, device) + noise_list(n, layers - tt, latent_dim, device)
 
+
 def latent_to_w(style_vectorizer, latent_descr):
     return [(style_vectorizer(z), num_layers) for z, num_layers in latent_descr]
 
+
 def image_noise(n, im_size, device):
     return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda(device)
 
+
 def leaky_relu(p=0.2):
     return nn.LeakyReLU(p, inplace=True)
 
+
 def evaluate_in_chunks(max_batch_size, model, *args):
     split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
     chunked_outputs = [model(*i) for i in split_args]
@@ -279,52 +313,56 @@ def evaluate_in_chunks(max_batch_size, model, *args):
         return chunked_outputs[0]
     return torch.cat(chunked_outputs, dim=0)
 
+
 def styles_def_to_tensor(styles_def):
     return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1)
 
+
 def set_requires_grad(model, bool):
     for p in model.parameters():
         p.requires_grad = bool
 
+
 def slerp(val, low, high):
     low_norm = low / torch.norm(low, dim=1, keepdim=True)
     high_norm = high / torch.norm(high, dim=1, keepdim=True)
     omega = torch.acos((low_norm * high_norm).sum(1))
     so = torch.sin(omega)
-    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
-    return res
+    return (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
+
 
 # losses
 
 def gen_hinge_loss(fake, real):
     return fake.mean()
 
+
 def hinge_loss(real, fake):
     return (F.relu(1 + real) + F.relu(1 - fake)).mean()
 
+
 def dual_contrastive_loss(real_logits, fake_logits):
     device = real_logits.device
     real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits))
 
     def loss_half(t1, t2):
         t1 = rearrange(t1, 'i -> i ()')
-        t2 = repeat(t2, 'j -> i j', i = t1.shape[0])
-        t = torch.cat((t1, t2), dim = -1)
-        return F.cross_entropy(t, torch.zeros(t1.shape[0], device = device, dtype = torch.long))
+        t2 = repeat(t2, 'j -> i j', i=t1.shape[0])
+        t = torch.cat((t1, t2), dim=-1)
+        return F.cross_entropy(t, torch.zeros(t1.shape[0], device=device, dtype=torch.long))
 
     return loss_half(real_logits, fake_logits) + loss_half(-fake_logits, -real_logits)
 
+
 # dataset
 
 def convert_rgb_to_transparent(image):
-    if image.mode != 'RGBA':
-        return image.convert('RGBA')
-    return image
+    return image.convert('RGBA') if image.mode != 'RGBA' else image
+
 
 def convert_transparent_to_rgb(image):
-    if image.mode != 'RGB':
-        return image.convert('RGB')
-    return image
+    return image.convert('RGB') if image.mode != 'RGB' else image
+
 
 class expand_greyscale(object):
     def __init__(self, transparent):
@@ -351,27 +389,35 @@ def __call__(self, tensor):
 
         return color if not self.transparent else torch.cat((color, alpha))
 
+
 def resize_to_minimum_size(min_size, image):
     if max(*image.size) < min_size:
         return torchvision.transforms.functional.resize(image, min_size)
     return image
 
+
 class Dataset(data.Dataset):
-    def __init__(self, folder, image_size, transparent = False, aug_prob = 0.):
+    def __init__(self,
+                 folder,
+                 image_size,
+                 transparent=False,
+                 aug_prob=0.):
         super().__init__()
         self.folder = folder
         self.image_size = image_size
         self.paths = [p for ext in EXTS for p in Path(f'{folder}').glob(f'**/*.{ext}')]
-        assert len(self.paths) > 0, f'No images were found in {folder} for training'
+        assert self.paths, f'No images were found in {folder} for training'
+
+        convert_image_fn = convert_rgb_to_transparent if transparent else convert_transparent_to_rgb
 
-        convert_image_fn = convert_transparent_to_rgb if not transparent else convert_rgb_to_transparent
-        num_channels = 3 if not transparent else 4
+        num_channels = 4 if transparent else 3
 
         self.transform = transforms.Compose([
             transforms.Lambda(convert_image_fn),
             transforms.Lambda(partial(resize_to_minimum_size, image_size)),
             transforms.Resize(image_size),
-            RandomApply(aug_prob, transforms.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)), transforms.CenterCrop(image_size)),
+            RandomApply(aug_prob, transforms.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)),
+                        transforms.CenterCrop(image_size)),
             transforms.ToTensor(),
             transforms.Lambda(expand_greyscale(transparent))
         ])
@@ -384,32 +430,33 @@ def __getitem__(self, index):
         img = Image.open(path)
         return self.transform(img)
 
+
 # augmentations
 
 def random_hflip(tensor, prob):
-    if prob < random():
-        return tensor
-    return torch.flip(tensor, dims=(3,))
+    return tensor if prob < random() else torch.flip(tensor, dims=(3,))
+
 
 class AugWrapper(nn.Module):
     def __init__(self, D, image_size):
         super().__init__()
         self.D = D
 
-    def forward(self, images, prob = 0., types = [], detach = False):
+    def forward(self, images, prob=0.0, types=None, detach=False):
+        if types is None:
+            types = []
         if random() < prob:
             images = random_hflip(images, prob=0.5)
             images = DiffAugment(images, types=types)
-
         if detach:
             images = images.detach()
-
         return self.D(images)
 
+
 # stylegan2 classes
 
 class EqualLinear(nn.Module):
-    def __init__(self, in_dim, out_dim, lr_mul = 1, bias = True):
+    def __init__(self, in_dim, out_dim, lr_mul=1, bias=True):
         super().__init__()
         self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
         if bias:
@@ -420,8 +467,9 @@ def __init__(self, in_dim, out_dim, lr_mul = 1, bias = True):
     def forward(self, input):
         return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
 
+
 class StyleVectorizer(nn.Module):
-    def __init__(self, emb, depth, lr_mul = 0.1):
+    def __init__(self, emb, depth, lr_mul=0.1):
         super().__init__()
 
         layers = []
@@ -434,8 +482,9 @@ def forward(self, x):
         x = F.normalize(x, dim=1)
         return self.net(x)
 
+
 class RGBBlock(nn.Module):
-    def __init__(self, latent_dim, input_channel, upsample, rgba = False):
+    def __init__(self, latent_dim, input_channel, upsample, rgba=False):
         super().__init__()
         self.input_channel = input_channel
         self.to_style = nn.Linear(latent_dim, input_channel)
@@ -444,7 +493,7 @@ def __init__(self, latent_dim, input_channel, upsample, rgba = False):
         self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False)
 
         self.upsample = nn.Sequential(
-            nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
             Blur()
         ) if upsample else None
 
@@ -461,8 +510,9 @@ def forward(self, x, prev_rgb, istyle):
 
         return x
 
+
 class Conv2DMod(nn.Module):
-    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, eps = 1e-8, **kwargs):
+    def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, eps=1e-8, **kwargs):
         super().__init__()
         self.filters = out_chan
         self.demod = demod
@@ -473,7 +523,8 @@ def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1,
         self.eps = eps
         nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
 
-    def _get_same_padding(self, size, kernel, dilation, stride):
+    @staticmethod
+    def _get_same_padding(size, kernel, dilation, stride):
         return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
 
     def forward(self, x, y):
@@ -498,15 +549,16 @@ def forward(self, x, y):
         x = x.reshape(-1, self.filters, h, w)
         return x
 
+
 class GeneratorBlock(nn.Module):
-    def __init__(self, latent_dim, input_channels, filters, upsample = True, upsample_rgb = True, rgba = False):
+    def __init__(self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, rgba=False):
         super().__init__()
         self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None
 
         self.to_style1 = nn.Linear(latent_dim, input_channels)
         self.to_noise1 = nn.Linear(1, filters)
         self.conv1 = Conv2DMod(input_channels, filters, 3)
-        
+
         self.to_style2 = nn.Linear(latent_dim, filters)
         self.to_noise2 = nn.Linear(1, filters)
         self.conv2 = Conv2DMod(filters, filters, 3)
@@ -533,10 +585,11 @@ def forward(self, x, prev_rgb, istyle, inoise):
         rgb = self.to_rgb(x, prev_rgb, istyle)
         return x, rgb
 
+
 class DiscriminatorBlock(nn.Module):
     def __init__(self, input_channels, filters, downsample=True):
         super().__init__()
-        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1))
+        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))
 
         self.net = nn.Sequential(
             nn.Conv2d(input_channels, filters, 3, padding=1),
@@ -547,7 +600,7 @@ def __init__(self, input_channels, filters, downsample=True):
 
         self.downsample = nn.Sequential(
             Blur(),
-            nn.Conv2d(filters, filters, 3, padding = 1, stride = 2)
+            nn.Conv2d(filters, filters, 3, padding=1, stride=2)
         ) if downsample else None
 
     def forward(self, x):
@@ -558,8 +611,19 @@ def forward(self, x):
         x = (x + res) * (1 / math.sqrt(2))
         return x
 
+
 class Generator(nn.Module):
-    def __init__(self, image_size, latent_dim, network_capacity = 16, transparent = False, attn_layers = [], no_const = False, fmap_max = 512):
+    def __init__(self,
+                 image_size,
+                 latent_dim,
+                 network_capacity=16,
+                 transparent=False,
+                 attn_layers=None,
+                 no_const=False,
+                 fmap_max=512):
+        if attn_layers is None:
+            attn_layers = []
+
         super().__init__()
         self.image_size = image_size
         self.latent_dim = latent_dim
@@ -597,9 +661,9 @@ def __init__(self, image_size, latent_dim, network_capacity = 16, transparent =
                 latent_dim,
                 in_chan,
                 out_chan,
-                upsample = not_first,
-                upsample_rgb = not_last,
-                rgba = transparent
+                upsample=not_first,
+                upsample_rgb=not_last,
+                rgba=transparent
             )
             self.blocks.append(block)
 
@@ -624,11 +688,13 @@ def forward(self, styles, input_noise):
 
         return rgb
 
+
 class Discriminator(nn.Module):
-    def __init__(self, image_size, network_capacity = 16, fq_layers = [], fq_dict_size = 256, attn_layers = [], transparent = False, fmap_max = 512):
+    def __init__(self, image_size, network_capacity=16, fq_layers=[], fq_dict_size=256, attn_layers=[],
+                 transparent=False, fmap_max=512):
         super().__init__()
         num_layers = int(log2(image_size) - 1)
-        num_init_filters = 3 if not transparent else 4
+        num_init_filters = 4 if transparent else 3
 
         blocks = []
         filters = [num_init_filters] + [(network_capacity * 4) * (2 ** i) for i in range(num_layers + 1)]
@@ -645,7 +711,7 @@ def __init__(self, image_size, network_capacity = 16, fq_layers = [], fq_dict_si
             num_layer = ind + 1
             is_not_last = ind != (len(chan_in_out) - 1)
 
-            block = DiscriminatorBlock(in_chan, out_chan, downsample = is_not_last)
+            block = DiscriminatorBlock(in_chan, out_chan, downsample=is_not_last)
             blocks.append(block)
 
             attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None
@@ -686,19 +752,44 @@ def forward(self, x):
         x = self.to_logit(x)
         return x.squeeze(), quantize_loss
 
+
 class StyleGAN2(nn.Module):
-    def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, ttur_mult = 2, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, lr_mlp = 0.1, rank = 0):
+    def __init__(self,
+                 image_size,
+                 latent_dim=512,
+                 fmap_max=512,
+                 style_depth=8,
+                 network_capacity=16,
+                 transparent=False,
+                 fp16=False,
+                 cl_reg=False,
+                 steps=1,
+                 lr=1e-4,
+                 ttur_mult=2,
+                 fq_layers=None,
+                 fq_dict_size=256,
+                 attn_layers=None,
+                 no_const=False,
+                 lr_mlp=0.1,
+                 rank=0):
         super().__init__()
+        if attn_layers is None:
+            attn_layers = []
+        if fq_layers is None:
+            fq_layers = []
         self.lr = lr
         self.steps = steps
         self.ema_updater = EMA(0.995)
 
-        self.S = StyleVectorizer(latent_dim, style_depth, lr_mul = lr_mlp)
-        self.G = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const, fmap_max = fmap_max)
-        self.D = Discriminator(image_size, network_capacity, fq_layers = fq_layers, fq_dict_size = fq_dict_size, attn_layers = attn_layers, transparent = transparent, fmap_max = fmap_max)
+        self.S = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
+        self.G = Generator(image_size, latent_dim, network_capacity, transparent=transparent, attn_layers=attn_layers,
+                           no_const=no_const, fmap_max=fmap_max)
+        self.D = Discriminator(image_size, network_capacity, fq_layers=fq_layers, fq_dict_size=fq_dict_size,
+                               attn_layers=attn_layers, transparent=transparent, fmap_max=fmap_max)
 
-        self.SE = StyleVectorizer(latent_dim, style_depth, lr_mul = lr_mlp)
-        self.GE = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const)
+        self.SE = StyleVectorizer(latent_dim, style_depth, lr_mul=lr_mlp)
+        self.GE = Generator(image_size, latent_dim, network_capacity, transparent=transparent, attn_layers=attn_layers,
+                            no_const=no_const)
 
         self.D_cl = None
 
@@ -717,8 +808,8 @@ def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8
 
         # init optimizers
         generator_params = list(self.G.parameters()) + list(self.S.parameters())
-        self.G_opt = Adam(generator_params, lr = self.lr, betas=(0.5, 0.9))
-        self.D_opt = Adam(self.D.parameters(), lr = self.lr * ttur_mult, betas=(0.5, 0.9))
+        self.G_opt = Adam(generator_params, lr=self.lr, betas=(0.5, 0.9))
+        self.D_opt = Adam(self.D.parameters(), lr=self.lr * ttur_mult, betas=(0.5, 0.9))
 
         # init weights
         self._init_weights()
@@ -729,7 +820,8 @@ def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8
         # startup apex mixed precision
         self.fp16 = fp16
         if fp16:
-            (self.S, self.G, self.D, self.SE, self.GE), (self.G_opt, self.D_opt) = amp.initialize([self.S, self.G, self.D, self.SE, self.GE], [self.G_opt, self.D_opt], opt_level='O1', num_losses=3)
+            (self.S, self.G, self.D, self.SE, self.GE), (self.G_opt, self.D_opt) = amp.initialize(
+                [self.S, self.G, self.D, self.SE, self.GE], [self.G_opt, self.D_opt], opt_level='O1', num_losses=3)
 
     def _init_weights(self):
         for m in self.modules():
@@ -758,53 +850,60 @@ def reset_parameter_averaging(self):
     def forward(self, x):
         return x
 
+
 class Trainer():
     def __init__(
-        self,
-        name = 'default',
-        results_dir = 'results',
-        models_dir = 'models',
-        base_dir = './',
-        image_size = 128,
-        network_capacity = 16,
-        fmap_max = 512,
-        transparent = False,
-        batch_size = 4,
-        mixed_prob = 0.9,
-        gradient_accumulate_every=1,
-        lr = 2e-4,
-        lr_mlp = 0.1,
-        ttur_mult = 2,
-        rel_disc_loss = False,
-        num_workers = None,
-        save_every = 1000,
-        evaluate_every = 1000,
-        num_image_tiles = 8,
-        trunc_psi = 0.6,
-        fp16 = False,
-        cl_reg = False,
-        no_pl_reg = False,
-        fq_layers = [],
-        fq_dict_size = 256,
-        attn_layers = [],
-        no_const = False,
-        aug_prob = 0.,
-        aug_types = ['translation', 'cutout'],
-        top_k_training = False,
-        generator_top_k_gamma = 0.99,
-        generator_top_k_frac = 0.5,
-        dual_contrast_loss = False,
-        dataset_aug_prob = 0.,
-        calculate_fid_every = None,
-        calculate_fid_num_images = 12800,
-        clear_fid_cache = False,
-        is_ddp = False,
-        rank = 0,
-        world_size = 1,
-        log = False,
-        *args,
-        **kwargs
+            self,
+            name='default',
+            results_dir='results',
+            models_dir='models',
+            base_dir='./',
+            image_size=128,
+            network_capacity=16,
+            fmap_max=512,
+            transparent=False,
+            batch_size=4,
+            mixed_prob=0.9,
+            gradient_accumulate_every=1,
+            lr=2e-4,
+            lr_mlp=0.1,
+            ttur_mult=2,
+            rel_disc_loss=False,
+            num_workers=None,
+            save_every=1000,
+            evaluate_every=1000,
+            num_image_tiles=8,
+            trunc_psi=0.6,
+            fp16=False,
+            cl_reg=False,
+            no_pl_reg=False,
+            fq_layers=None,
+            fq_dict_size=256,
+            attn_layers=None,
+            no_const=False,
+            aug_prob=0.,
+            aug_types=None,
+            top_k_training=False,
+            generator_top_k_gamma=0.99,
+            generator_top_k_frac=0.5,
+            dual_contrast_loss=False,
+            dataset_aug_prob=0.,
+            calculate_fid_every=None,
+            calculate_fid_num_images=12800,
+            clear_fid_cache=False,
+            is_ddp=False,
+            rank=0,
+            world_size=1,
+            log=False,
+            *args,
+            **kwargs
     ):
+        if fq_layers is None:
+            fq_layers = []
+        if attn_layers is None:
+            attn_layers = []
+        if aug_types is None:
+            aug_types = ['translation', 'cutout']
         self.GAN_params = [args, kwargs]
         self.GAN = None
 
@@ -892,7 +991,7 @@ def __init__(
 
     @property
     def image_extension(self):
-        return 'jpg' if not self.transparent else 'png'
+        return 'png' if self.transparent else 'jpg'
 
     @property
     def checkpoint_num(self):
@@ -901,10 +1000,14 @@ def checkpoint_num(self):
     @property
     def hparams(self):
         return {'image_size': self.image_size, 'network_capacity': self.network_capacity}
-        
+
     def init_GAN(self):
         args, kwargs = self.GAN_params
-        self.GAN = StyleGAN2(lr = self.lr, lr_mlp = self.lr_mlp, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, fmap_max = self.fmap_max, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, rank = self.rank, *args, **kwargs)
+        self.GAN = StyleGAN2(lr=self.lr, lr_mlp=self.lr_mlp, ttur_mult=self.ttur_mult, image_size=self.image_size,
+                             network_capacity=self.network_capacity, fmap_max=self.fmap_max,
+                             transparent=self.transparent, fq_layers=self.fq_layers, fq_dict_size=self.fq_dict_size,
+                             attn_layers=self.attn_layers, fp16=self.fp16, cl_reg=self.cl_reg, no_const=self.no_const,
+                             rank=self.rank, *args, **kwargs)
 
         if self.is_ddp:
             ddp_kwargs = {'device_ids': [self.rank]}
@@ -920,7 +1023,8 @@ def write_config(self):
         self.config_path.write_text(json.dumps(self.config()))
 
     def load_config(self):
-        config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text())
+        config = json.loads(self.config_path.read_text()) if self.config_path.exists() else self.config()
+
         self.image_size = config['image_size']
         self.network_capacity = config['network_capacity']
         self.transparent = config['transparent']
@@ -934,13 +1038,18 @@ def load_config(self):
         self.init_GAN()
 
     def config(self):
-        return {'image_size': self.image_size, 'network_capacity': self.network_capacity, 'lr_mlp': self.lr_mlp, 'transparent': self.transparent, 'fq_layers': self.fq_layers, 'fq_dict_size': self.fq_dict_size, 'attn_layers': self.attn_layers, 'no_const': self.no_const}
+        return {'image_size': self.image_size, 'network_capacity': self.network_capacity, 'lr_mlp': self.lr_mlp,
+                'transparent': self.transparent, 'fq_layers': self.fq_layers, 'fq_dict_size': self.fq_dict_size,
+                'attn_layers': self.attn_layers, 'no_const': self.no_const}
 
     def set_data_src(self, folder):
-        self.dataset = Dataset(folder, self.image_size, transparent = self.transparent, aug_prob = self.dataset_aug_prob)
-        num_workers = num_workers = default(self.num_workers, NUM_CORES if not self.is_ddp else 0)
-        sampler = DistributedSampler(self.dataset, rank=self.rank, num_replicas=self.world_size, shuffle=True) if self.is_ddp else None
-        dataloader = data.DataLoader(self.dataset, num_workers = num_workers, batch_size = math.ceil(self.batch_size / self.world_size), sampler = sampler, shuffle = not self.is_ddp, drop_last = True, pin_memory = True)
+        self.dataset = Dataset(folder, self.image_size, transparent=self.transparent, aug_prob=self.dataset_aug_prob)
+        num_workers = default(self.num_workers, 0 if self.is_ddp else NUM_CORES)
+        sampler = DistributedSampler(self.dataset, rank=self.rank, num_replicas=self.world_size,
+                                     shuffle=True) if self.is_ddp else None
+        dataloader = data.DataLoader(self.dataset, num_workers=num_workers,
+                                     batch_size=math.ceil(self.batch_size / self.world_size), sampler=sampler,
+                                     shuffle=not self.is_ddp, drop_last=True, pin_memory=True)
         self.loader = cycle(dataloader)
 
         # auto set augmentation prob for user if dataset is detected to be low
@@ -965,8 +1074,8 @@ def train(self):
         latent_dim = self.GAN.G.latent_dim
         num_layers = self.GAN.G.num_layers
 
-        aug_prob   = self.aug_prob
-        aug_types  = self.aug_types
+        aug_prob = self.aug_prob
+        aug_types = self.aug_types
         aug_kwargs = {'prob': aug_prob, 'types': aug_types}
 
         apply_gradient_penalty = self.steps % 4 == 0
@@ -1001,7 +1110,7 @@ def train(self):
                 loss = self.GAN.D_cl.calculate_loss()
                 self.last_cr_loss = loss.clone().detach().item()
 
-                backwards(loss, self.GAN.D_opt, loss_id = 0)
+                backwards(loss, self.GAN.D_opt, loss_id=0)
 
             self.GAN.D_opt.step()
 
@@ -1030,7 +1139,7 @@ def train(self):
                 w_styles = styles_def_to_tensor(w_space)
 
                 generated_images = G(w_styles, noise)
-                fake_output, fake_q_loss = D_aug(generated_images.clone().detach(), detach = True, **aug_kwargs)
+                fake_output, fake_q_loss = D_aug(generated_images.clone().detach(), detach=True, **aug_kwargs)
 
                 image_batch = next(self.loader).cuda(self.rank)
                 image_batch.requires_grad_()
@@ -1060,7 +1169,7 @@ def train(self):
                 disc_loss = disc_loss / self.gradient_accumulate_every
                 disc_loss.register_hook(raise_if_nan)
-                backwards(disc_loss, self.GAN.D_opt, loss_id = 1)
+                backwards(disc_loss, self.GAN.D_opt, loss_id=1)
 
                 total_disc_loss += divergence.detach().item() / self.gradient_accumulate_every
 
@@ -1087,7 +1196,7 @@ def train(self):
                 real_output = None
                 if G_requires_reals:
                     image_batch = next(self.loader).cuda(self.rank)
-                    real_output, _ = D_aug(image_batch, detach = True, **aug_kwargs)
+                    real_output, _ = D_aug(image_batch, detach=True, **aug_kwargs)
                     real_output = real_output.detach()
 
                 if self.top_k_training:
@@ -1112,7 +1221,7 @@ def train(self):
                 gen_loss = gen_loss / self.gradient_accumulate_every
                 gen_loss.register_hook(raise_if_nan)
-                backwards(gen_loss, self.GAN.G_opt, loss_id = 2)
+                backwards(gen_loss, self.GAN.G_opt, loss_id=2)
 
                 total_gen_loss += loss.detach().item() / self.gradient_accumulate_every
 
@@ -1161,11 +1270,11 @@ def train(self):
             self.av = None
 
     @torch.no_grad()
-    def evaluate(self, num = 0, trunc = 1.0):
+    def evaluate(self, num=0, trunc=1.0):
         self.GAN.eval()
         ext = self.image_extension
         num_rows = self.num_image_tiles
-        
+
         latent_dim = self.GAN.G.latent_dim
         image_size = self.GAN.G.image_size
         num_layers = self.GAN.G.num_layers
@@ -1177,13 +1286,15 @@ def evaluate(self, num = 0, trunc = 1.0):
 
         # regular
 
-        generated_images = self.generate_truncated(self.GAN.S, self.GAN.G, latents, n, trunc_psi = self.trunc_psi)
-        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows)
-        
+        generated_images = self.generate_truncated(self.GAN.S, self.GAN.G, latents, n, trunc_psi=self.trunc_psi)
+        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}.{ext}'),
+                                     nrow=num_rows)
+
         # moving averages
 
-        generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, n, trunc_psi = self.trunc_psi)
-        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'), nrow=num_rows)
+        generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, n, trunc_psi=self.trunc_psi)
+        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'),
+                                     nrow=num_rows)
 
         # mixing regularities
 
@@ -1192,7 +1303,8 @@ def tile(a, dim, n_tile):
             repeat_idx = [1] * a.dim()
             repeat_idx[dim] = n_tile
             a = a.repeat(*(repeat_idx))
-            order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).cuda(self.rank)
+            order_index = torch.LongTensor(
+                np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).cuda(self.rank)
             return torch.index_select(a, dim, order_index)
 
         nn = noise(num_rows, latent_dim, device=self.rank)
@@ -1202,8 +1314,9 @@ def tile(a, dim, n_tile):
         tt = int(num_layers / 2)
         mixed_latents = [(tmp1, tt), (tmp2, num_layers - tt)]
 
-        generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, mixed_latents, n, trunc_psi = self.trunc_psi)
-        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-mr.{ext}'), nrow=num_rows)
+        generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, mixed_latents, n, trunc_psi=self.trunc_psi)
+        torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-mr.{ext}'),
+                                     nrow=num_rows)
 
     @torch.no_grad()
     def calculate_fid(self, num_batches):
@@ -1243,15 +1356,17 @@ def calculate_fid(self, num_batches):
             noise = image_noise(self.batch_size, image_size, device=self.rank)
 
             # moving averages
-            generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, noise, trunc_psi = self.trunc_psi)
+            generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, noise,
+                                                       trunc_psi=self.trunc_psi)
 
             for j, image in enumerate(generated_images.unbind(0)):
-                torchvision.utils.save_image(image, str(fake_path / f'{str(j + batch_num * self.batch_size)}-ema.{ext}'))
+                torchvision.utils.save_image(image,
+                                             str(fake_path / f'{str(j + batch_num * self.batch_size)}-ema.{ext}'))
 
         return fid_score.calculate_fid_given_paths([str(real_path), str(fake_path)], 256, noise.device, 2048)
 
     @torch.no_grad()
-    def truncate_style(self, tensor, trunc_psi = 0.75):
+    def truncate_style(self, tensor, trunc_psi=0.75):
         S = self.GAN.S
         batch_size = self.batch_size
         latent_dim = self.GAN.G.latent_dim
@@ -1259,31 +1374,31 @@ def truncate_style(self, tensor, trunc_psi = 0.75):
         if not exists(self.av):
             z = noise(2000, latent_dim, device=self.rank)
             samples = evaluate_in_chunks(batch_size, S, z).cpu().numpy()
-            self.av = np.mean(samples, axis = 0)
-            self.av = np.expand_dims(self.av, axis = 0)
+            self.av = np.mean(samples, axis=0)
+            self.av = np.expand_dims(self.av, axis=0)
 
         av_torch = torch.from_numpy(self.av).cuda(self.rank)
         tensor = trunc_psi * (tensor - av_torch) + av_torch
         return tensor
 
     @torch.no_grad()
-    def truncate_style_defs(self, w, trunc_psi = 0.75):
+    def truncate_style_defs(self, w, trunc_psi=0.75):
         w_space = []
         for tensor, num_layers in w:
-            tensor = self.truncate_style(tensor, trunc_psi = trunc_psi)
+            tensor = self.truncate_style(tensor, trunc_psi=trunc_psi)
             w_space.append((tensor, num_layers))
         return w_space
 
     @torch.no_grad()
-    def generate_truncated(self, S, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8):
+    def generate_truncated(self, S, G, style, noi, trunc_psi=0.75, num_image_tiles=8):
         w = map(lambda t: (S(t[0]), t[1]), style)
-        w_truncated = self.truncate_style_defs(w, trunc_psi = trunc_psi)
+        w_truncated = self.truncate_style_defs(w, trunc_psi=trunc_psi)
         w_styles = styles_def_to_tensor(w_truncated)
         generated_images = evaluate_in_chunks(self.batch_size, G, w_styles, noi)
         return generated_images.clamp_(0., 1.)
 
     @torch.no_grad()
-    def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0, num_steps = 100, save_frames = False):
+    def generate_interpolation(self, num=0, num_image_tiles=8, trunc=1.0, num_steps=100, save_frames=False):
         self.GAN.eval()
         ext = self.image_extension
         num_rows = num_image_tiles
@@ -1304,17 +1419,18 @@ def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0, num_
         for ratio in tqdm(ratios):
             interp_latents = slerp(ratio, latents_low, latents_high)
             latents = [(interp_latents, num_layers)]
-            generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, n, trunc_psi = self.trunc_psi)
-            images_grid = torchvision.utils.make_grid(generated_images, nrow = num_rows)
+            generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, n, trunc_psi=self.trunc_psi)
+            images_grid = torchvision.utils.make_grid(generated_images, nrow=num_rows)
             pil_image = transforms.ToPILImage()(images_grid.cpu())
-            
+
             if self.transparent:
                 background = Image.new("RGBA", pil_image.size, (255, 255, 255))
                 pil_image = Image.alpha_composite(background, pil_image)
-                
+
             frames.append(pil_image)
 
-        frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True)
+        frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:],
+                       duration=80, loop=0, optimize=True)
 
         if save_frames:
             folder_path = (self.results_dir / self.name / f'{str(num)}')
@@ -1340,7 +1456,7 @@ def print_log(self):
     def track(self, value, name):
         if not exists(self.logger):
             return
-        self.logger.track(value, name = name)
+        self.logger.track(value, name=name)
 
     def model_name(self, num):
         return str(self.models_dir / self.name / f'model_{num}.pt')
@@ -1368,7 +1484,7 @@ def save(self, num):
         torch.save(save_data, self.model_name(num))
         self.write_config()
 
-    def load(self, num = -1):
+    def load(self, num=-1):
         self.load_config()
 
         name = num
@@ -1390,17 +1506,19 @@ def load(self, num = -1):
         try:
             self.GAN.load_state_dict(load_data['GAN'])
         except Exception as e:
-            print('unable to load save model. please try downgrading the package to the version specified by the saved model')
+            print(
+                'unable to load save model. please try downgrading the package to the version specified by the saved model')
             raise e
 
         if self.GAN.fp16 and 'amp' in load_data:
             amp.load_state_dict(load_data['amp'])
 
+
 class ModelLoader:
-    def __init__(self, *, base_dir, name = 'default', load_from = -1):
-        self.model = Trainer(name = name, base_dir = base_dir)
+    def __init__(self, *, base_dir, name='default', load_from=-1):
+        self.model = Trainer(name=name, base_dir=base_dir)
         self.model.load(load_from)
 
-    def noise_to_styles(self, noise, trunc_psi = None):
+    def noise_to_styles(self, noise, trunc_psi=None):
         noise = noise.cuda()
         w = self.model.GAN.SE(noise)
         if exists(trunc_psi):
@@ -1414,7 +1532,7 @@ def styles_to_images(self, w):
         w_def = [(w, num_layers)]
         w_tensors = styles_def_to_tensor(w_def)
-        noise = image_noise(batch_size, image_size, device = 0)
+        noise = image_noise(batch_size, image_size, device=0)
 
         images = self.model.GAN.GE(w_tensors, noise)
         images.clamp_(0., 1.)