Fix some type errors #46

Open · wants to merge 1 commit into base: main
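The fixes below follow two recurring patterns: wrapping calls whose return type the checker sees as `Any` in `typing.cast`, and replacing implicit-Optional defaults such as `seed: int = None` with explicit `Optional[int]`. A minimal standalone sketch of both patterns (the `Model` and `load_model` names are illustrative, not from this repo):

    from typing import Any, Optional, cast

    class Model:
        def eval(self) -> "Model":
            return self

    def load_model(name: str) -> Any:
        # stands in for an untyped loader such as load_flow_model
        return Model()

    # cast() is a no-op at runtime; it only informs the type checker
    model = cast(Model, load_model("flux-schnell"))
    model = model.eval()  # checker now knows eval() exists and returns Model

    def generate(seed: Optional[int] = None) -> int:
        # explicit Optional instead of `seed: int = None`
        return 0 if seed is None else seed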
fp8/flux_pipeline.py: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 import random
 import shutil
 import warnings
-from typing import TYPE_CHECKING, Callable, List
+from typing import TYPE_CHECKING, Callable, List, Tuple

 import numpy as np
 from PIL import Image
@@ -386,7 +386,7 @@ def into_bytes(self, x: torch.Tensor, jpeg_quality: int = 99) -> io.BytesIO:
         return im

     @torch.inference_mode()
-    def as_img_tensor(self, x: torch.Tensor) -> io.BytesIO:
+    def as_img_tensor(self, x: torch.Tensor) -> Tuple[List[Image.Image], List[np.ndarray]]:
         """Converts the image tensor to bytes."""
         # bring into PIL format and save
         num_images = x.shape[0]
@@ -550,7 +550,7 @@ def generate(
         num_images: int = 1,
         jpeg_quality: int = 99,
         compiling: bool = False,
-    ) -> io.BytesIO:
+    ) -> Tuple[List[Image.Image], List[np.ndarray]]:
         """
         Generate images based on the given prompt and parameters.

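With the corrected annotations, `as_img_tensor` and `generate` are documented to return a pair of parallel lists rather than a single `io.BytesIO`. A hypothetical caller (assuming a constructed pipeline object `pipe` exposing the `generate` shown above) would unpack the result like so:

    # hypothetical usage; `pipe` construction elided
    pil_images, np_images = pipe.generate(prompt="a cool dog", num_images=2)
    pil_images[0].save("out.jpg", quality=99)  # PIL.Image.Image
    pixels = np_images[0]                      # np.ndarray, e.g. for a safety checker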
predict.py: 37 additions & 31 deletions
@@ -1,6 +1,6 @@
 import os
 import time
-from typing import Any, Tuple, Optional
+from typing import Any, Tuple, Dict, Optional, cast

 import torch

Expand All @@ -24,6 +24,7 @@
from torchvision import transforms
from cog import BasePredictor, Input, Path
from flux.util import load_ae, load_clip, load_flow_model, load_t5, download_weights
from flux.model import Flux as BF16Flux

from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
@@ -153,7 +154,9 @@ def base_setup(
         self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
             SAFETY_CACHE, torch_dtype=torch.float16
         ).to("cuda")
-        self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)
+        self.feature_extractor = cast(
+            CLIPImageProcessor, CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)
+        )

         print("Loading Falcon safety checker...")
         if not FALCON_MODEL_CACHE.exists():
@@ -162,7 +165,9 @@
             FALCON_MODEL_NAME,
             cache_dir=FALCON_MODEL_CACHE,
         )
-        self.falcon_processor = ViTImageProcessor.from_pretrained(FALCON_MODEL_NAME)
+        self.falcon_processor = cast(
+            ViTImageProcessor, ViTImageProcessor.from_pretrained(FALCON_MODEL_NAME)
+        )

         # need > 48 GB of ram to store all models in VRAM
         total_mem = torch.cuda.get_device_properties(0).total_memory
@@ -175,8 +180,11 @@
         max_length = 256 if self.flow_model_name == "flux-schnell" else 512
         self.t5 = load_t5(device, max_length=max_length)
         self.clip = load_clip(device)
-        self.flux = load_flow_model(
-            self.flow_model_name, device="cpu" if self.offload else device
+        self.flux = cast(
+            BF16Flux,
+            load_flow_model(
+                self.flow_model_name, device="cpu" if self.offload else device
+            ),
         )
         self.flux = self.flux.eval()
         self.ae = load_ae(
@@ -195,13 +203,13 @@

         if not self.disable_fp8:
             if compile_fp8:
-                extra_args = {
+                extra_args: Dict[str, Any] = {
                     "compile_whole_model": True,
                     "compile_extras": True,
                     "compile_blocks": True,
                 }
             else:
-                extra_args = {
+                extra_args: Dict[str, Any] = {
                     "compile_whole_model": False,
                     "compile_extras": False,
                     "compile_blocks": False,
@@ -235,12 +243,11 @@ def compile_fp8(self):
             num_steps=self.num_steps,
             guidance=3,
             seed=123,
-            compiling=compile,
+            compiling=True,
         )

-        for k in ASPECT_RATIOS:
+        for k, (width, height) in ASPECT_RATIOS.items():
             print(f"warming kernel for {k}")
-            width, height = self.aspect_ratio_to_width_height(k)
             self.fp8_pipe.generate(
                 prompt="godzilla!", width=width, height=height, num_steps=4, guidance=3
             )
@@ -261,7 +268,8 @@ def compile_bf16(self):
         self.compile_run = True
         self.base_predict(
             prompt="a cool dog",
-            aspect_ratio="1:1",
+            width=1024,
+            height=1024,
             num_outputs=1,
             num_inference_steps=self.num_steps,
             guidance=3.5,
@@ -272,20 +280,18 @@

     def aspect_ratio_to_width_height(self, aspect_ratio: str):
         return ASPECT_RATIOS.get(aspect_ratio)

-    def get_image(self, image: str):
-        if image is None:
-            return None
-        image = Image.open(image).convert("RGB")
+    def get_image(self, image: str | Path) -> torch.Tensor:
+        image_opened = Image.open(image).convert("RGB")
         transform = transforms.Compose(
             [
                 transforms.ToTensor(),
                 transforms.Lambda(lambda x: 2.0 * x - 1.0),
             ]
         )
-        img: torch.Tensor = transform(image)
+        img: torch.Tensor = cast(torch.Tensor, transform(image_opened))
         return img[None, ...]

-    def predict():
+    def predict(self) -> List[Path]:
         raise Exception("You need to instantiate a predictor for a specific flux model")

     @torch.inference_mode()
@@ -323,7 +329,7 @@ def handle_loras(
             unload_loras(model)

     def preprocess(self, aspect_ratio: str, megapixels: str = "1") -> Tuple[int, int]:
-        width, height = ASPECT_RATIOS.get(aspect_ratio)
+        width, height = ASPECT_RATIOS[aspect_ratio]
         if megapixels == "0.25":
             width, height = width // 2, height // 2

@@ -336,12 +342,12 @@ def base_predict(
         num_outputs: int,
         num_inference_steps: int,
         guidance: float = 3.5,  # schnell ignores guidance within the model, fine to have default
-        image: Path = None,  # img2img for flux-dev
+        image: Optional[Path] = None,  # img2img for flux-dev
         prompt_strength: float = 0.8,
-        seed: int = None,
+        seed: Optional[int] = None,
         width: int = 1024,
         height: int = 1024,
-    ) -> List[Path]:
+    ) -> Tuple[List[Image.Image], List[np.ndarray]]:
         """Run a single prediction on the model"""
         torch_device = torch.device("cuda")
         init_image = None
@@ -406,7 +412,7 @@
         if self.offload:
             self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
             torch.cuda.empty_cache()
-            self.flux = self.flux.to(torch_device)
+            self.flux = cast(BF16Flux, self.flux.to(torch_device))

         x, flux = denoise(
             self.flux,
@@ -449,12 +455,12 @@ def fp8_predict(
         num_outputs: int,
         num_inference_steps: int,
         guidance: float = 3.5,  # schnell ignores guidance within the model, fine to have default
-        image: Path = None,  # img2img for flux-dev
+        image: Optional[Path] = None,  # img2img for flux-dev
         prompt_strength: float = 0.8,
-        seed: int = None,
+        seed: Optional[int] = None,
         width: int = 1024,
         height: int = 1024,
-    ) -> List[Image]:
+    ) -> Tuple[List[Image.Image], List[np.ndarray]]:
         """Run a single prediction on the model"""
         print("running quantized prediction")

@@ -465,18 +471,18 @@
             num_steps=num_inference_steps,
             guidance=guidance,
             seed=seed,
-            init_image=image,
+            init_image=str(image) if image else None,
             strength=prompt_strength,
             num_images=num_outputs,
         )

     def postprocess(
         self,
-        images: List[Image],
+        images: List[Image.Image],
         disable_safety_checker: bool,
         output_format: str,
         output_quality: int,
-        np_images: Optional[List[Image]] = None,
+        np_images: Optional[List[np.ndarray]] = None,
     ) -> List[Path]:
         has_nsfw_content = [False] * len(images)

@@ -542,12 +548,12 @@ def shared_predict(
         num_outputs: int,
         num_inference_steps: int,
         guidance: float = 3.5,  # schnell ignores guidance within the model, fine to have default
-        image: Path = None,  # img2img for flux-dev
+        image: Optional[Path] = None,  # img2img for flux-dev
         prompt_strength: float = 0.8,
-        seed: int = None,
+        seed: Optional[int] = None,
         width: int = 1024,
         height: int = 1024,
-    ):
+    ) -> Tuple[List[Image.Image], Optional[List[np.ndarray]]]:
         if go_fast and not self.disable_fp8:
             return self.fp8_predict(
                 prompt=prompt,
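A third pattern shows up in `preprocess` and the kernel warm-up loop: `dict.get()` returns `Optional[...]`, so unpacking its result is a type error, while indexing and `.items()` yield non-Optional values. A standalone sketch (this `ASPECT_RATIOS` literal is illustrative, not the repo's actual table):

    from typing import Dict, Tuple

    ASPECT_RATIOS: Dict[str, Tuple[int, int]] = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
    }

    # width, height = ASPECT_RATIOS.get("1:1")  # rejected: .get() returns Optional[Tuple[int, int]]
    width, height = ASPECT_RATIOS["1:1"]  # well-typed; raises KeyError at runtime if missing

    for name, (w, h) in ASPECT_RATIOS.items():  # pairs are non-Optional
        print(f"warming kernel for {name}: {w}x{h}")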