Image generation with deepspeed --fp16 #394

Open · wants to merge 4 commits into base: main
2 changes: 2 additions & 0 deletions dalle_pytorch/dalle_pytorch.py
@@ -508,6 +508,8 @@ def generate_images(
 
         return images
 
+
+    @torch.autocast(device_type="cuda", enabled=True)
     def forward(
         self,
         text,
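
The two added lines wrap DALLE.forward in a CUDA autocast region, so eligible ops in the training forward pass run in float16 while the parameters stay in float32. A minimal sketch of the same decorator pattern, using a hypothetical TinyNet stand-in rather than the real DALLE class (assumes PyTorch 1.10+ and a CUDA device):

import torch
from torch import nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for DALLE, only to show the decorator."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)

    # Eligible ops inside the decorated method (here the linear layer's
    # matmul) execute in float16; the weights themselves remain float32.
    @torch.autocast(device_type="cuda", enabled=True)
    def forward(self, x):
        return self.proj(x)

net = TinyNet().cuda()
out = net(torch.randn(2, 64, device="cuda"))
print(out.dtype)  # torch.float16, produced inside the autocast region
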
11 changes: 4 additions & 7 deletions dalle_pytorch/vae.py
@@ -1,22 +1,17 @@
import io
import sys
import os
import requests
import PIL
import warnings
import hashlib
import urllib
import yaml
from pathlib import Path
from tqdm import tqdm
from math import sqrt, log
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel, GumbelVQ
import importlib


import torch
from torch import nn
import torch.nn.functional as F
from torch.cuda.amp.autocast_mode import autocast

from einops import rearrange

@@ -196,6 +191,7 @@ def _register_external_parameters(self):
             self, self.model.quantize.embed.weight if self.is_gumbel else self.model.quantize.embedding.weight)
 
     @torch.no_grad()
+    @autocast(enabled=True, dtype=torch.float32, cache_enabled=True)
     def get_codebook_indices(self, img):
         b = img.shape[0]
         img = (2 * img) - 1
@@ -204,6 +200,7 @@ def get_codebook_indices(self, img):
             return rearrange(indices, 'b h w -> b (h w)', b=b)
         return rearrange(indices, '(b n) -> b n', b = b)
 
+    @autocast(enabled=True, dtype=torch.float32, cache_enabled=True)
     def decode(self, img_seq):
         b, n = img_seq.shape
         one_hot_indices = F.one_hot(img_seq, num_classes = self.num_tokens).float()
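
Both decorators pin the VQGAN codepaths to float32: even when a caller is running under a float16 autocast region, a nested autocast with dtype=torch.float32 retargets eligible ops back to full precision. A small sketch of that nested behaviour, with lin as an illustrative layer (assumes a CUDA device; not the PR's actual code):

import torch
from torch.cuda.amp.autocast_mode import autocast

lin = torch.nn.Linear(32, 32).cuda()

@autocast(enabled=True, dtype=torch.float32, cache_enabled=True)
def fp32_path(x):
    # The inner region retargets autocast at float32, so this matmul
    # stays in full precision even inside an fp16 region.
    return lin(x)

x = torch.randn(4, 32, device="cuda")
with autocast(enabled=True, dtype=torch.float16):
    y16 = lin(x)        # float16 under the outer fp16 region
    y32 = fp32_path(x)  # forced back to float32 by the decorator
print(y16.dtype, y32.dtype)  # torch.float16 torch.float32
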
20 changes: 8 additions & 12 deletions train_dalle.py
100644 → 100755
@@ -411,10 +411,12 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
 # initialize DALL-E
 
 dalle = DALLE(vae=vae, **dalle_params)
-if not using_deepspeed:
-    if args.fp16:
-        dalle = dalle.half()
-    dalle = dalle.cuda()
+if args.fp16:
+    dalle.vae.float()
+    for layer in dalle.modules():
+        if not isinstance(layer, VQGanVAE):  # VQGanVAE is not FP16 compatible
+            layer.half()
+dalle = dalle.cuda()
 
 if RESUME and not using_deepspeed:
     dalle.load_state_dict(weights)
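
This hunk replaces the blanket dalle.half() with a selective conversion that keeps the VQGAN VAE in float32, since its quantization step is not fp16 compatible. Note that nn.Module.half() converts parameters recursively, children included, so an equivalent way to reach the intended end state is to halve the whole model and then restore the VAE. A sketch of that idea, with FP32OnlyVAE and Wrapper as hypothetical stand-ins for VQGanVAE and DALLE:

import torch
from torch import nn

class FP32OnlyVAE(nn.Module):
    """Hypothetical stand-in for VQGanVAE, which must stay in float32."""
    def __init__(self):
        super().__init__()
        self.codebook = nn.Embedding(1024, 16)

class Wrapper(nn.Module):
    """Hypothetical stand-in for DALLE."""
    def __init__(self):
        super().__init__()
        self.transformer = nn.Linear(16, 16)  # safe to run in fp16
        self.vae = FP32OnlyVAE()              # fp16-incompatible part

model = Wrapper()
model.half()       # .half() recurses, so this also converts the VAE
model.vae.float()  # restore the fp32-only submodule afterwards

assert next(model.transformer.parameters()).dtype == torch.float16
assert next(model.vae.parameters()).dtype == torch.float32
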
@@ -505,7 +507,6 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
 # Prefer scheduler in `deepspeed_config`.
 if LR_DECAY and distr_scheduler is None:
     distr_scheduler = scheduler
-avoid_model_calls = using_deepspeed and args.fp16
 
 if RESUME and using_deepspeed:
     distr_dalle.load_checkpoint(str(cp_dir))
@@ -607,16 +608,11 @@ def save_model(path, epoch=0):
                 token_list = sample_text.masked_select(sample_text != 0).tolist()
                 decoded_text = tokenizer.decode(token_list)
 
-                if not avoid_model_calls:
-                    # CUDA index errors when we don't guard this
-                    image = dalle.generate_images(text[:1], filter_thres=0.9)  # topk sampling at 0.9
-
-
+                image = dalle.generate_images(text[:1], filter_thres=0.9)  # topk sampling at 0.9
                 log = {
                     **log,
                 }
-                if not avoid_model_calls:
-                    log['image'] = wandb.Image(image, caption=decoded_text)
+                log['image'] = wandb.Image(image, caption=decoded_text)
 
             if i % 10 == 9 and distr_backend.is_root_worker():
                 sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
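
With the VAE pinned to float32, the avoid_model_calls guard is no longer needed: image sampling can run even when DeepSpeed trains the rest of the model in fp16, so sample images get logged during deepspeed --fp16 runs as well. A sketch of the resulting logging path, assuming dalle, text, decoded_text, and an active wandb run exist as in train_dalle.py:

import torch
import wandb

# `dalle` has an fp16 transformer and an fp32 VAE, as set up above.
with torch.no_grad():
    image = dalle.generate_images(text[:1], filter_thres=0.9)  # top-k sampling at 0.9
wandb.log({'image': wandb.Image(image, caption=decoded_text)})
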