import os
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import cog
import random
import tempfile
from pathlib import Path
import torch
from PIL import Image
import clip
import torchvision
from main import load_vqgan_model, CLIP_DIM, clamp_with_grad, synth, load_clip_model, load_model, load_prior_model
from download_weights import PRIOR_MODEL

MODELS = [
    "cc12m_32x1024_vitgan_clip_ViTB32_256x256_v0.1.th",
    "cc12m_32x1024_vitgan_clip_ViTB32_256x256_v0.2.th",
    "cc12m_32x1024_mlp_mixer_clip_ViTB32_256x256_v0.2.th",
    "cc12m_32x1024_mlp_mixer_clip_ViTB32_256x256_v0.3.th",
    "cc12m_32x1024_mlp_mixer_cloob_rn50_256x256_v0.3.th",
    "cc12m_256x16_xtransformer_clip_ViTB32_512x512_v0.3.th",
    "cc12m_32x1024_mlp_mixer_clip_ViTB32_pixelrecons_256x256_v0.4.th",
    "cc12m_32x1024_mlp_mixer_openclip_laion2b_ViTB32_256x256_v0.4.th",
    "cc12m_32x1024_mlp_mixer_openclip_laion2b_imgEmb_ViTB32_256x256_v0.4.th",
    "cc12m_1x1024_mlp_mixer_openclip_laion2b_ViTB32_512x512_v0.4.th",
]
DEFAULT_MODEL = "cc12m_32x1024_mlp_mixer_openclip_laion2b_ViTB32_256x256_v0.4.th"
GRID_SIZES = [
    "1x1",
    "2x2",
    "4x4",
    "8x8",
]


def download_and_load_model(model_path):
    # Fetch the checkpoint if it is not present locally, then load the feed-forward network.
    download_from_path(model_path)
    return load_model(model_path)


def download_and_load_prior_model(model_path):
    # Same as above, but for the prior network used to sample image embeddings.
    download_from_path(model_path)
    return load_prior_model(model_path)


def download_from_path(path):
    from download_weights import download, model_url
    download(model_url[path], target=path)


class Predictor(cog.BasePredictor):

    def setup(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.nets = {
            model_path: download_and_load_model(model_path).cpu()
            for model_path in MODELS
        }
        self.priors = {}
        self.perceptors = {}
        self.vqgans = {}
        for path, net in self.nets.items():
            config = net.config
            vqgan_config = config.vqgan_config
            vqgan_checkpoint = config.vqgan_checkpoint
            clip_model = config.clip_model
            clip_model_path = config.get("clip_model_path")
            # Load CLIP model if not already done
            if (clip_model, clip_model_path) not in self.perceptors:
                perceptor = load_clip_model(clip_model, path=clip_model_path).eval().requires_grad_(False).to(self.device)
                self.perceptors[(clip_model, clip_model_path)] = perceptor
            # Load VQGAN model if not already done
            if (vqgan_config, vqgan_checkpoint) not in self.vqgans:
                model = load_vqgan_model(vqgan_config, vqgan_checkpoint).to(self.device)
                z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
                z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
                self.vqgans[(vqgan_config, vqgan_checkpoint)] = model, z_min, z_max
            # Load PRIOR model if not already done
            if PRIOR_MODEL[path] not in self.priors:
                self.priors[PRIOR_MODEL[path]] = download_and_load_prior_model(PRIOR_MODEL[path]).cpu()

    def predict(
        self,
        prompt: str = cog.Input(description="prompt for generating image"),
        model: str = cog.Input(description="Model", default=DEFAULT_MODEL, choices=MODELS + ["random"]),
        prior: bool = cog.Input(description="Use prior", default=False),
        grid_size: str = cog.Input(description="Grid size", default="1x1", choices=GRID_SIZES),
        seed: int = cog.Input(description="Seed", default=0),
    ) -> cog.Path:
        torch.manual_seed(seed)
        random.seed(seed)
        if model == "random":
            model = random.choice(list(self.nets.keys()))
        net = self.nets[model]
        net.to(self.device)
        config = net.config
        clip_model = config.clip_model
        clip_model_path = config.get("clip_model_path")
        vqgan_config = config.vqgan_config
        vqgan_checkpoint = config.vqgan_checkpoint
        grid_size_h, grid_size_w = grid_size.split("x")
        grid_size_h = int(grid_size_h)
        grid_size_w = int(grid_size_w)
        toks = clip.tokenize([prompt], truncate=True)
        perceptor = self.perceptors[(clip_model, clip_model_path)]
        vqgan, z_min, z_max = self.vqgans[(vqgan_config, vqgan_checkpoint)]
        if prior:
            prior_model = self.priors[PRIOR_MODEL[model]]
            prior_model.to(self.device)
        with torch.no_grad():
            # Encode the text prompt with CLIP, one copy per grid cell
            H = perceptor.encode_text(toks.to(self.device)).float()
            H = H.repeat(grid_size_h * grid_size_w, 1)
            if prior:
                # Sample image embeddings from the prior, conditioned on the text embedding
                H = H.view(len(H), -1, 1, 1)
                H = prior_model.sample(H)
                H = H.view(len(H), -1)
            # Map the embeddings to VQGAN latents and decode them into images
            z = net(H)
            z = clamp_with_grad(z, z_min.min(), z_max.max())
            xr = synth(vqgan, z)
        grid = torchvision.utils.make_grid(xr.cpu(), nrow=grid_size_h)
        out_path = cog.Path(tempfile.mkdtemp()) / "out.png"
        torchvision.transforms.functional.to_pil_image(grid).save(out_path)
        # Move models back to CPU to free GPU memory between predictions
        net.cpu()
        if prior:
            prior_model.cpu()
        return out_path
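

# --- Usage sketch (not part of the upstream file) ---
# The class above follows the cog BasePredictor interface, so the usual entry
# point is the cog CLI, e.g. `cog predict -i prompt="a watercolor of a fox"`.
# The guarded block below is a minimal local sketch of the same flow; it
# assumes the checkpoints listed in MODELS can be downloaded and that enough
# GPU/CPU memory is available, since setup() loads every model.
if __name__ == "__main__":
    predictor = Predictor()
    predictor.setup()
    out = predictor.predict(
        prompt="a watercolor of a fox",
        model=DEFAULT_MODEL,
        prior=False,
        grid_size="1x1",
        seed=0,
    )
    print(out)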