From 83986f818a306ca8fee49985ae3adbd40b835278 Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Mon, 11 Nov 2024 18:43:14 +0100 Subject: [PATCH] Fix some type errors --- fp8/flux_pipeline.py | 6 ++-- predict.py | 68 ++++++++++++++++++++++++-------------------- 2 files changed, 40 insertions(+), 34 deletions(-) diff --git a/fp8/flux_pipeline.py b/fp8/flux_pipeline.py index c208abf..41bffb9 100644 --- a/fp8/flux_pipeline.py +++ b/fp8/flux_pipeline.py @@ -4,7 +4,7 @@ import random import shutil import warnings -from typing import TYPE_CHECKING, Callable, List +from typing import TYPE_CHECKING, Callable, List, Tuple import numpy as np from PIL import Image @@ -386,7 +386,7 @@ def into_bytes(self, x: torch.Tensor, jpeg_quality: int = 99) -> io.BytesIO: return im @torch.inference_mode() - def as_img_tensor(self, x: torch.Tensor) -> io.BytesIO: + def as_img_tensor(self, x: torch.Tensor) -> Tuple[List[Image.Image], List[np.ndarray]]: """Converts the image tensor to bytes.""" # bring into PIL format and save num_images = x.shape[0] @@ -550,7 +550,7 @@ def generate( num_images: int = 1, jpeg_quality: int = 99, compiling: bool = False, - ) -> io.BytesIO: + ) -> Tuple[List[Image.Image], List[np.ndarray]]: """ Generate images based on the given prompt and parameters. diff --git a/predict.py b/predict.py index 98cf382..9d320a8 100644 --- a/predict.py +++ b/predict.py @@ -1,6 +1,6 @@ import os import time -from typing import Any, Tuple, Optional +from typing import Any, Tuple, Dict, Optional, cast import torch @@ -24,6 +24,7 @@ from torchvision import transforms from cog import BasePredictor, Input, Path from flux.util import load_ae, load_clip, load_flow_model, load_t5, download_weights +from flux.model import Flux as BF16Flux from diffusers.pipelines.stable_diffusion.safety_checker import ( StableDiffusionSafetyChecker, @@ -153,7 +154,9 @@ def base_setup( self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( SAFETY_CACHE, torch_dtype=torch.float16 ).to("cuda") - self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR) + self.feature_extractor = cast( + CLIPImageProcessor, CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR) + ) print("Loading Falcon safety checker...") if not FALCON_MODEL_CACHE.exists(): @@ -162,7 +165,9 @@ def base_setup( FALCON_MODEL_NAME, cache_dir=FALCON_MODEL_CACHE, ) - self.falcon_processor = ViTImageProcessor.from_pretrained(FALCON_MODEL_NAME) + self.falcon_processor = cast( + ViTImageProcessor, ViTImageProcessor.from_pretrained(FALCON_MODEL_NAME) + ) # need > 48 GB of ram to store all models in VRAM total_mem = torch.cuda.get_device_properties(0).total_memory @@ -175,8 +180,11 @@ def base_setup( max_length = 256 if self.flow_model_name == "flux-schnell" else 512 self.t5 = load_t5(device, max_length=max_length) self.clip = load_clip(device) - self.flux = load_flow_model( - self.flow_model_name, device="cpu" if self.offload else device + self.flux = cast( + BF16Flux, + load_flow_model( + self.flow_model_name, device="cpu" if self.offload else device + ), ) self.flux = self.flux.eval() self.ae = load_ae( @@ -195,13 +203,13 @@ def base_setup( if not self.disable_fp8: if compile_fp8: - extra_args = { + extra_args: Dict[str, Any] = { "compile_whole_model": True, "compile_extras": True, "compile_blocks": True, } else: - extra_args = { + extra_args: Dict[str, Any] = { "compile_whole_model": False, "compile_extras": False, "compile_blocks": False, @@ -235,12 +243,11 @@ def compile_fp8(self): num_steps=self.num_steps, guidance=3, seed=123, - compiling=compile, + compiling=True, ) - for k in ASPECT_RATIOS: + for k, (width, height) in ASPECT_RATIOS.items(): print(f"warming kernel for {k}") - width, height = self.aspect_ratio_to_width_height(k) self.fp8_pipe.generate( prompt="godzilla!", width=width, height=height, num_steps=4, guidance=3 ) @@ -261,7 +268,8 @@ def compile_bf16(self): self.compile_run = True self.base_predict( prompt="a cool dog", - aspect_ratio="1:1", + width=1024, + height=1024, num_outputs=1, num_inference_steps=self.num_steps, guidance=3.5, @@ -272,20 +280,18 @@ def compile_bf16(self): def aspect_ratio_to_width_height(self, aspect_ratio: str): return ASPECT_RATIOS.get(aspect_ratio) - def get_image(self, image: str): - if image is None: - return None - image = Image.open(image).convert("RGB") + def get_image(self, image: str | Path) -> torch.Tensor: + image_opened = Image.open(image).convert("RGB") transform = transforms.Compose( [ transforms.ToTensor(), transforms.Lambda(lambda x: 2.0 * x - 1.0), ] ) - img: torch.Tensor = transform(image) + img: torch.Tensor = cast(torch.Tensor, transform(image_opened)) return img[None, ...] - def predict(): + def predict(self) -> List[Path]: raise Exception("You need to instantiate a predictor for a specific flux model") @torch.inference_mode() @@ -323,7 +329,7 @@ def handle_loras( unload_loras(model) def preprocess(self, aspect_ratio: str, megapixels: str = "1") -> Tuple[int, int]: - width, height = ASPECT_RATIOS.get(aspect_ratio) + width, height = ASPECT_RATIOS[aspect_ratio] if megapixels == "0.25": width, height = width // 2, height // 2 @@ -336,12 +342,12 @@ def base_predict( num_outputs: int, num_inference_steps: int, guidance: float = 3.5, # schnell ignores guidance within the model, fine to have default - image: Path = None, # img2img for flux-dev + image: Optional[Path] = None, # img2img for flux-dev prompt_strength: float = 0.8, - seed: int = None, + seed: Optional[int] = None, width: int = 1024, height: int = 1024, - ) -> List[Path]: + ) -> Tuple[List[Image.Image], List[np.ndarray]]: """Run a single prediction on the model""" torch_device = torch.device("cuda") init_image = None @@ -406,7 +412,7 @@ def base_predict( if self.offload: self.t5, self.clip = self.t5.cpu(), self.clip.cpu() torch.cuda.empty_cache() - self.flux = self.flux.to(torch_device) + self.flux = cast(BF16Flux, self.flux.to(torch_device)) x, flux = denoise( self.flux, @@ -449,12 +455,12 @@ def fp8_predict( num_outputs: int, num_inference_steps: int, guidance: float = 3.5, # schnell ignores guidance within the model, fine to have default - image: Path = None, # img2img for flux-dev + image: Optional[Path] = None, # img2img for flux-dev prompt_strength: float = 0.8, - seed: int = None, + seed: Optional[int] = None, width: int = 1024, height: int = 1024, - ) -> List[Image]: + ) -> Tuple[List[Image.Image], List[np.ndarray]]: """Run a single prediction on the model""" print("running quantized prediction") @@ -465,18 +471,18 @@ def fp8_predict( num_steps=num_inference_steps, guidance=guidance, seed=seed, - init_image=image, + init_image=str(image) if image else None, strength=prompt_strength, num_images=num_outputs, ) def postprocess( self, - images: List[Image], + images: List[Image.Image], disable_safety_checker: bool, output_format: str, output_quality: int, - np_images: Optional[List[Image]] = None, + np_images: Optional[List[np.ndarray]] = None, ) -> List[Path]: has_nsfw_content = [False] * len(images) @@ -542,12 +548,12 @@ def shared_predict( num_outputs: int, num_inference_steps: int, guidance: float = 3.5, # schnell ignores guidance within the model, fine to have default - image: Path = None, # img2img for flux-dev + image: Optional[Path] = None, # img2img for flux-dev prompt_strength: float = 0.8, - seed: int = None, + seed: Optional[int] = None, width: int = 1024, height: int = 1024, - ): + ) -> Tuple[List[Image.Image], Optional[List[np.ndarray]]]: if go_fast and not self.disable_fp8: return self.fp8_predict( prompt=prompt,