import logging
import os
from enum import Enum
from typing import List, Optional, Tuple

import PIL
import torch
from diffusers import (
    AutoPipelineForText2Image,
    EulerDiscreteScheduler,
    FluxPipeline,
    StableDiffusion3Pipeline,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from diffusers.models import AutoencoderKL
from huggingface_hub import file_download, hf_hub_download
from safetensors.torch import load_file

from app.pipelines.base import Pipeline
from app.pipelines.utils import (
    LoraLoader,
    SafetyChecker,
    get_model_dir,
    get_torch_device,
    is_lightning_model,
    is_turbo_model,
    split_prompt,
)
from app.utils.errors import InferenceError

logger = logging.getLogger(__name__)
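
# Environment variables read by this pipeline (collected from the code below):
#   BFLOAT16              - load weights in bfloat16 precision
#   TORCH_COMPILE         - compile the unet and vae decode with torch.compile
#   SFAST                 - "true" enables stable-fast dynamic compilation
#   DEEPCACHE             - "true" enables DeepCache (non-Lightning/Turbo models only)
#   SFAST_WARMUP          - warm-up flag (not yet supported; ignored with a warning)
#   SAFETY_CHECKER_DEVICE - device for the NSFW safety checker (default "cuda")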


class ModelName(Enum):
    """Enumeration mapping model names to their corresponding IDs."""

    SDXL_LIGHTNING = "ByteDance/SDXL-Lightning"
    SD3 = "stabilityai/stable-diffusion-3"
    REALISTIC_VISION_V6 = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
    FLUX_1_SCHNELL = "black-forest-labs/FLUX.1-schnell"
    FLUX_1_DEV = "black-forest-labs/FLUX.1-dev"

    @classmethod
    def list(cls):
        """Return a list of all model IDs."""
        return [model.value for model in cls]
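
# For reference, ModelName.list() returns the raw model IDs above, i.e.:
#   ["ByteDance/SDXL-Lightning", "stabilityai/stable-diffusion-3",
#    "SG161222/Realistic_Vision_V6.0_B1_noVAE",
#    "black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev"]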


class TextToImagePipeline(Pipeline):
    def __init__(self, model_id: str):
        self.model_id = model_id
        kwargs = {"cache_dir": get_model_dir()}

        torch_device = get_torch_device()
        folder_name = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model"
        )
        folder_path = os.path.join(get_model_dir(), folder_name)
        # Load the fp16 variant if fp16 'safetensors' files are present in the cache.
        # NOTE: Exception for the SDXL-Lightning model: despite having fp16
        # 'safetensors' files, they are not named according to the standard convention.
        has_fp16_variant = (
            any(
                ".fp16.safetensors" in fname
                for _, _, files in os.walk(folder_path)
                for fname in files
            )
            or ModelName.SDXL_LIGHTNING.value in model_id
        )
        if torch_device.type != "cpu" and has_fp16_variant:
            logger.info("TextToImagePipeline loading fp16 variant for %s", model_id)
            kwargs["torch_dtype"] = torch.float16
            kwargs["variant"] = "fp16"

        if os.environ.get("BFLOAT16"):
            logger.info(
                "TextToImagePipeline using bfloat16 precision for %s", model_id
            )
            kwargs["torch_dtype"] = torch.bfloat16

        # Load VAE for specific models.
        if ModelName.REALISTIC_VISION_V6.value in model_id:
            vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
            kwargs["vae"] = vae

        # Special case SDXL-Lightning because the unet for SDXL needs to be swapped.
        if ModelName.SDXL_LIGHTNING.value in model_id:
            base = "stabilityai/stable-diffusion-xl-base-1.0"

            # ByteDance/SDXL-Lightning-2step
            if "2step" in model_id:
                unet_id = "sdxl_lightning_2step_unet"
            # ByteDance/SDXL-Lightning-4step
            elif "4step" in model_id:
                unet_id = "sdxl_lightning_4step_unet"
            # ByteDance/SDXL-Lightning-8step
            elif "8step" in model_id:
                unet_id = "sdxl_lightning_8step_unet"
            else:
                # Default to 8step.
                unet_id = "sdxl_lightning_8step_unet"

            unet_config = UNet2DConditionModel.load_config(
                pretrained_model_name_or_path=base,
                subfolder="unet",
                cache_dir=kwargs["cache_dir"],
            )
            unet = UNet2DConditionModel.from_config(unet_config).to(
                torch_device, kwargs["torch_dtype"]
            )
            unet.load_state_dict(
                load_file(
                    hf_hub_download(
                        ModelName.SDXL_LIGHTNING.value,
                        f"{unet_id}.safetensors",
                        cache_dir=kwargs["cache_dir"],
                    ),
                    device=str(torch_device),
                )
            )

            self.ldm = StableDiffusionXLPipeline.from_pretrained(
                base, unet=unet, **kwargs
            ).to(torch_device)
            self.ldm.scheduler = EulerDiscreteScheduler.from_config(
                self.ldm.scheduler.config, timestep_spacing="trailing"
            )
        elif ModelName.SD3.value in model_id:
            # Decrease precision to prevent OOM errors.
            kwargs["torch_dtype"] = torch.bfloat16
            self.ldm = StableDiffusion3Pipeline.from_pretrained(model_id, **kwargs).to(
                torch_device
            )
        elif (
            ModelName.FLUX_1_SCHNELL.value in model_id
            or ModelName.FLUX_1_DEV.value in model_id
        ):
            # Decrease precision to prevent OOM errors.
            kwargs["torch_dtype"] = torch.bfloat16
            self.ldm = FluxPipeline.from_pretrained(model_id, **kwargs).to(torch_device)
        else:
            self.ldm = AutoPipelineForText2Image.from_pretrained(model_id, **kwargs).to(
                torch_device
            )

        if os.environ.get("TORCH_COMPILE"):
            torch._inductor.config.conv_1x1_as_mm = True
            torch._inductor.config.coordinate_descent_tuning = True
            torch._inductor.config.epilogue_fusion = False
            torch._inductor.config.coordinate_descent_check_all_directions = True

            self.ldm.unet.to(memory_format=torch.channels_last)
            self.ldm.vae.to(memory_format=torch.channels_last)
            self.ldm.unet = torch.compile(
                self.ldm.unet, mode="max-autotune", fullgraph=True
            )
            self.ldm.vae.decode = torch.compile(
                self.ldm.vae.decode, mode="max-autotune", fullgraph=True
            )

        sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
        deepcache_enabled = os.getenv("DEEPCACHE", "").strip().lower() == "true"
        if sfast_enabled and deepcache_enabled:
            logger.warning(
                "Both 'SFAST' and 'DEEPCACHE' are enabled. This is not recommended "
                "as it may lead to suboptimal performance. Please disable one of them."
            )

        if sfast_enabled:
            logger.info(
                "TextToImagePipeline will be dynamically compiled with stable-fast "
                "for %s",
                model_id,
            )
            from app.pipelines.optim.sfast import compile_model

            self.ldm = compile_model(self.ldm)

            # Warm-up the pipeline.
            # TODO: Not yet supported for TextToImagePipeline.
            if os.getenv("SFAST_WARMUP", "true").lower() == "true":
                logger.warning(
                    "The 'SFAST_WARMUP' flag is not yet supported for the "
                    "TextToImagePipeline and will be ignored. As a result the first "
                    "call may be slow if 'SFAST' is enabled."
                )

        if deepcache_enabled and not (
            is_lightning_model(model_id) or is_turbo_model(model_id)
        ):
            logger.info(
                "TextToImagePipeline will be optimized with DeepCache for %s",
                model_id,
            )
            from app.pipelines.optim.deepcache import enable_deepcache

            self.ldm = enable_deepcache(self.ldm)
        elif deepcache_enabled:
            logger.warning(
                "DeepCache is not supported for Lightning or Turbo models. "
                "TextToImagePipeline will NOT be optimized with DeepCache for %s",
                model_id,
            )

        safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
        self._safety_checker = SafetyChecker(device=safety_checker_device)
        self._lora_loader = LoraLoader(self.ldm)

    def __call__(
        self, prompt: str, **kwargs
    ) -> Tuple[List[PIL.Image.Image], List[Optional[bool]]]:
        seed = kwargs.pop("seed", None)
        safety_check = kwargs.pop("safety_check", True)
        loras_json = kwargs.pop("loras", "")

        if seed is not None:
            if isinstance(seed, int):
                kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(
                    seed
                )
            elif isinstance(seed, list):
                kwargs["generator"] = [
                    torch.Generator(get_torch_device()).manual_seed(s) for s in seed
                ]
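        # Note: when `seed` is a list, one generator is created per entry; diffusers
        # expects the number of generators to match the number of images requested.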

        # Dynamically (un)load LoRas.
        if not loras_json:
            self._lora_loader.disable_loras()
        else:
            self._lora_loader.load_loras(loras_json)

        if "num_inference_steps" in kwargs and (
            kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1
        ):
            del kwargs["num_inference_steps"]

        if (
            self.model_id == "stabilityai/sdxl-turbo"
            or self.model_id == "stabilityai/sd-turbo"
        ):
            # SD turbo models were trained without guidance_scale, so it should be
            # set to 0.
            kwargs["guidance_scale"] = 0.0
        elif ModelName.SDXL_LIGHTNING.value in self.model_id:
            # SDXL-Lightning models should have guidance_scale = 0 and use the
            # correct number of inference steps for the unet checkpoint loaded.
            kwargs["guidance_scale"] = 0.0
            if "2step" in self.model_id:
                kwargs["num_inference_steps"] = 2
            elif "4step" in self.model_id:
                kwargs["num_inference_steps"] = 4
            elif "8step" in self.model_id:
                kwargs["num_inference_steps"] = 8
            else:
                # Default to 8step.
                kwargs["num_inference_steps"] = 8
        elif (
            ModelName.FLUX_1_SCHNELL.value in self.model_id
            or ModelName.FLUX_1_DEV.value in self.model_id
        ):
            # FluxPipeline does not support negative prompts per the diffusers
            # documentation:
            # https://github.com/huggingface/diffusers/blob/750bd7920622b3fe538d20035d3f03855c5d6621/src/diffusers/pipelines/flux/pipeline_flux.py#L537
            if "negative_prompt" in kwargs:
                kwargs.pop("negative_prompt")

        # Allow users to specify multiple (negative) prompts using the '|' separator.
        prompts = split_prompt(prompt, max_splits=3)
        prompt = prompts.pop("prompt")
        kwargs.update(prompts)
        neg_prompts = split_prompt(
            kwargs.pop("negative_prompt", ""),
            key_prefix="negative_prompt",
            max_splits=3,
        )
        kwargs.update(neg_prompts)
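        # Illustration (assuming split_prompt numbers extra segments to match the
        # pipeline's prompt_2/prompt_3 parameters):
        #   "a cat | studio lighting" -> prompt="a cat", kwargs["prompt_2"]="studio lighting"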

        try:
            outputs = self.ldm(prompt=prompt, **kwargs)
        except torch.cuda.OutOfMemoryError as e:
            raise e
        except Exception as e:
            raise InferenceError(original_exception=e)

        if safety_check:
            _, has_nsfw_concept = self._safety_checker.check_nsfw_images(
                outputs.images
            )
        else:
            has_nsfw_concept = [None] * len(outputs.images)

        return outputs.images, has_nsfw_concept

    def __str__(self) -> str:
        return f"TextToImagePipeline model_id={self.model_id}"