diff --git a/OmniGen/model.py b/OmniGen/model.py
index 167aa75..f8e0f15 100644
--- a/OmniGen/model.py
+++ b/OmniGen/model.py
@@ -5,6 +5,7 @@
 import numpy as np
 import math
 from typing import Dict
+import torch.nn.functional as F
 
 from diffusers.loaders import PeftAdapterMixin
 from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
@@ -16,7 +17,7 @@ def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
 
-    
+
 class TimestepEmbedder(nn.Module):
     """
@@ -149,19 +150,54 @@ def forward(self, x):
         return x
 
 
+class Int8Quantized(nn.Module):
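+    """
+    Stores a weight tensor as int8 with a single symmetric scale factor
+    (max_abs / 127); the float weight is reconstructed on demand in forward().
+    """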
+    def __init__(self, tensor, scale_factor=None):
+        super().__init__()
+        if scale_factor is None:
+            max_val = torch.max(torch.abs(tensor))
+            scale_factor = max_val / 127.0
+        # Store quantized weights and scale factor
+        self.register_buffer('quantized_weight', torch.round(tensor / scale_factor).to(torch.int8))
+        self.register_buffer('scale_factor', torch.tensor(scale_factor))
+
+    def forward(self, dtype=None):
+        # Dequantize and convert to specified dtype
+        weight = self.quantized_weight.float() * self.scale_factor
+        if dtype is not None:
+            weight = weight.to(dtype)
+        return weight
+
+
+class QuantizedLinear(nn.Module):
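+    """
+    Drop-in replacement for nn.Linear that keeps its weight in int8 form and
+    dequantizes it to the input dtype on each forward pass.
+    """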
+    def __init__(self, weight, bias=None):
+        super().__init__()
+        self.weight_quantized = Int8Quantized(weight)
+        if bias is not None:
+            self.register_buffer('bias', bias)
+        else:
+            self.bias = None
+
+    def forward(self, x):
+        # Dequantize weight to match input dtype
+        weight = self.weight_quantized(dtype=x.dtype)
+        return F.linear(x, weight, self.bias)
+
+
 class OmniGen(nn.Module, PeftAdapterMixin):
     """
     Diffusion model with a Transformer backbone.
     """
     def __init__(
         self,
-        transformer_config: Phi3Config,
+        transformer_config=Phi3Config,
         patch_size=2,
         in_channels=4,
         pe_interpolation: float = 1.0,
         pos_embed_max_size: int = 192,
     ):
         super().__init__()
+
         self.in_channels = in_channels
         self.out_channels = in_channels
         self.patch_size = patch_size
@@ -174,7 +210,7 @@ def __init__(
 
         self.time_token = TimestepEmbedder(hidden_size)
         self.t_embedder = TimestepEmbedder(hidden_size)
-        
+
         self.pe_interpolation = pe_interpolation
         pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
         self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
@@ -185,24 +221,46 @@
         self.llm = Phi3Transformer(config=transformer_config)
         self.llm.config.use_cache = False
-
+
+    def _quantize_module(self, module):
+        """
+        Quantize a module to 8-bit precision
+        """
+        for name, child in module.named_children():
+            if isinstance(child, nn.Linear):
+                setattr(module, name, QuantizedLinear(child.weight.data, child.bias.data if child.bias is not None else None))
+            elif isinstance(child, nn.LayerNorm):
+                # Skip quantization for LayerNorm
+                continue
+            else:
+                self._quantize_module(child)
+
     @classmethod
-    def from_pretrained(cls, model_name):
+    def from_pretrained(cls, model_name, quantize=False):  # Add quantize parameter
         if not os.path.exists(model_name):
             cache_folder = os.getenv('HF_HUB_CACHE')
             model_name = snapshot_download(repo_id=model_name,
-                                           cache_dir=cache_folder,
-                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
+                                            cache_dir=cache_folder,
+                                            ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
+
         config = Phi3Config.from_pretrained(model_name)
         model = cls(config)
+
         if os.path.exists(os.path.join(model_name, 'model.safetensors')):
             print("Loading safetensors")
             ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
         else:
             ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
+
+        # Load weights first
         model.load_state_dict(ckpt)
-        return model
+        # Only quantize if explicitly requested
+        if quantize:
+            print("Quantizing weights to 8-bit...")
+            model._quantize_module(model.llm)
+
+        return model
 
     def initialize_weights(self):
         assert not hasattr(self, "llama")
diff --git a/OmniGen/pipeline.py b/OmniGen/pipeline.py
index 978bdf0..cd9d0a1 100644
--- a/OmniGen/pipeline.py
+++ b/OmniGen/pipeline.py
@@ -20,8 +20,9 @@
 from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
+import gc # For clearing unused objects
 
-logger = logging.get_logger(__name__) 
+logger = logging.get_logger(__name__)
 
 EXAMPLE_DOC_STRING = """
     Examples:
@@ -40,14 +41,13 @@
         ```
 """
-
-
 class OmniGenPipeline:
     def __init__(
         self,
         vae: AutoencoderKL,
         model: OmniGen,
         processor: OmniGenProcessor,
+
     ):
         self.vae = vae
         self.model = model
@@ -59,34 +59,38 @@ def __init__(
         self.vae.to(self.device)
 
     @classmethod
-    def from_pretrained(cls, model_name, vae_path: str=None):
+    def from_pretrained(cls, model_name, vae_path: str=None, Quantization: bool=False):
         if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"):
             logger.info("Model not found, downloading...")
             cache_folder = os.getenv('HF_HUB_CACHE')
             model_name = snapshot_download(repo_id=model_name,
-                                           cache_dir=cache_folder,
-                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt'])
+                                            cache_dir=cache_folder,
+                                            ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt'])
             logger.info(f"Downloaded model to {model_name}")
-        model = OmniGen.from_pretrained(model_name)
+
+        # Pass Quantization parameter to OmniGen's from_pretrained
+        model = OmniGen.from_pretrained(model_name, quantize=Quantization)
+
         processor = OmniGenProcessor.from_pretrained(model_name)
 
         if os.path.exists(os.path.join(model_name, "vae")):
             vae = AutoencoderKL.from_pretrained(os.path.join(model_name, "vae"))
         elif vae_path is not None:
-            vae = AutoencoderKL.from_pretrained(vae_path).to(device)
+            vae = AutoencoderKL.from_pretrained(vae_path)
         else:
             logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF")
-            vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
+            vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
 
         return cls(vae, model, processor)
-    
+
+
+
     def merge_lora(self, lora_path: str):
         model = PeftModel.from_pretrained(self.model, lora_path)
         model.merge_and_unload()
-
         self.model = model
-    
+
     def to(self, device: Union[str, torch.device]):
         if isinstance(device, str):
             device = torch.device(device)
@@ -101,7 +105,7 @@ def vae_encode(self, x, dtype):
             x = self.vae.encode(x).latent_dist.sample().mul_(self.vae.config.scaling_factor)
         x = x.to(dtype)
         return x
-    
+
     def move_to_device(self, data):
         if isinstance(data, list):
             return [x.to(self.device) for x in data]
@@ -124,13 +128,15 @@ def __call__(
         use_kv_cache: bool = True,
         dtype: torch.dtype = torch.bfloat16,
         seed: int = None,
+        Quantization: bool = False,
     ):
+
         r"""
         Function invoked when calling the pipeline for generation.
 
         Args:
             prompt (`str` or `List[str]`):
-                The prompt or prompts to guide the image generation. 
+                The prompt or prompts to guide the image generation.
             input_images (`List[str]` or `List[List[str]]`, *optional*):
                 The list of input images. We will replace the "<|image_i|>" in prompt with the i-th image in list.
             height (`int`, *optional*, defaults to 1024):
@@ -146,9 +152,9 @@ def __call__(
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
             use_img_guidance (`bool`, *optional*, defaults to True):
-                Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800). 
+                Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
             img_guidance_scale (`float`, *optional*, defaults to 1.6):
-                Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800). 
+                Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
             separate_cfg_infer (`bool`, *optional*, defaults to False):
                 Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
             use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
@@ -160,6 +166,8 @@ def __call__(
         Returns:
             A list with the generated images.
         """
+
+
         assert height%16 == 0 and width%16 == 0
         if separate_cfg_infer:
             use_kv_cache = False
@@ -170,63 +178,112 @@ def __call__(
             prompt = [prompt]
             input_images = [input_images] if input_images is not None else None
 
+
         input_data = self.processor(prompt, input_images, height=height, width=width, use_img_cfg=use_img_guidance, separate_cfg_input=separate_cfg_infer)
 
         num_prompt = len(prompt)
         num_cfg = 2 if use_img_guidance else 1
-        latent_size_h, latent_size_w = height//8, width//8
+        latent_size_h, latent_size_w = height // 8, width // 8
 
         if seed is not None:
             generator = torch.Generator(device=self.device).manual_seed(seed)
         else:
             generator = None
         latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator)
-        latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
+        latents = torch.cat([latents] * (1 + num_cfg), 0).to(dtype)
+
+
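+        # To keep peak VRAM usage low, the VAE and the transformer are moved to
+        # the GPU only for the stage that needs them and offloaded back to the
+        # CPU (with cache clearing) in between: encode -> denoise -> decode.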
+        # Load VAE into VRAM (GPU) in bfloat16
+        self.vae.to(self.device, dtype=torch.bfloat16)
+
+
         input_img_latents = []
         if separate_cfg_infer:
             for temp_pixel_values in input_data['input_pixel_values']:
                 temp_input_latents = []
                 for img in temp_pixel_values:
-                    img = self.vae_encode(img.to(self.device), dtype)
+                    img = self.vae_encode(img.to(self.device, dtype=torch.bfloat16), dtype)
+
                     temp_input_latents.append(img)
                 input_img_latents.append(temp_input_latents)
         else:
             for img in input_data['input_pixel_values']:
-                img = self.vae_encode(img.to(self.device), dtype)
+                img = self.vae_encode(img.to(self.device, dtype=torch.bfloat16), dtype)
+
                 input_img_latents.append(img)
 
-        model_kwargs = dict(input_ids=self.move_to_device(input_data['input_ids']),
-            input_img_latents=input_img_latents,
-            input_image_sizes=input_data['input_image_sizes'],
-            attention_mask=self.move_to_device(input_data["attention_mask"]),
-            position_ids=self.move_to_device(input_data["position_ids"]),
+
+
+        model_kwargs = dict(input_ids=self.move_to_device(input_data['input_ids']),
+            input_img_latents=input_img_latents,
+            input_image_sizes=input_data['input_image_sizes'],
+            attention_mask=self.move_to_device(input_data["attention_mask"]),
+            position_ids=self.move_to_device(input_data["position_ids"]),
             cfg_scale=guidance_scale,
             img_cfg_scale=img_guidance_scale,
             use_img_cfg=use_img_guidance,
             use_kv_cache=use_kv_cache)
-
+
+
+        # Unload VAE to CPU
+        self.vae.to('cpu')
+        torch.cuda.empty_cache() # Clear VRAM
+        gc.collect() # Run garbage collection to free system RAM
+
+
         if separate_cfg_infer:
             func = self.model.forward_with_separate_cfg
         else:
             func = self.model.forward_with_cfg
-        self.model.to(dtype)
+
+
+        # Move main model to GPU
+        self.model.to(self.device, dtype=dtype)
+
         scheduler = OmniGenScheduler(num_steps=num_inference_steps)
         samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache)
-        samples = samples.chunk((1+num_cfg), dim=0)[0]
+        samples = samples.chunk((1 + num_cfg), dim=0)[0]
 
-        samples = samples.to(torch.float32)
         if self.vae.config.shift_factor is not None:
             samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
         else:
-            samples = samples / self.vae.config.scaling_factor 
+            samples = samples / self.vae.config.scaling_factor
+
+        # Unload main model to CPU
+        self.model.to('cpu')
+        torch.cuda.empty_cache() # Clear VRAM
+        gc.collect() # Run garbage collection to free system RAM
+
+        # Move samples to GPU and ensure they are in bfloat16 (for the VAE)
+        samples = samples.to(self.device, dtype=torch.bfloat16)
+
+        # Load VAE into VRAM (GPU) in bfloat16
+        self.vae.to(self.device, dtype=torch.bfloat16)
+
+        # Decode the samples using the VAE
         samples = self.vae.decode(samples).sample
-
-        output_samples = (samples * 0.5 + 0.5).clamp(0, 1)*255
+
+        # Unload VAE to CPU
+        self.vae.to('cpu')
+        torch.cuda.empty_cache() # Clear VRAM
+        gc.collect() # Run garbage collection to free system RAM
+
+
+        # Convert samples back to float32 for further processing
+        samples = samples.to(torch.float32)
+
+
+        # Convert samples to uint8 for final image output
+        output_samples = (samples * 0.5 + 0.5).clamp(0, 1) * 255
         output_samples = output_samples.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
 
+        # Create output images
         output_images = []
-        for i, sample in enumerate(output_samples):
+        for i, sample in enumerate(output_samples):
             output_images.append(Image.fromarray(sample))
-
-        return output_images
\ No newline at end of file
+
+        # Return the generated images
+        return output_images
diff --git a/app.py b/app.py
index 19970ee..fd35bc8 100644
--- a/app.py
+++ b/app.py
@@ -2,36 +2,67 @@
 from PIL import Image
 import os
 import spaces
+from threading import Lock
 
 from OmniGen import OmniGenPipeline
 
-pipe = OmniGenPipeline.from_pretrained(
-    "Shitao/OmniGen-v1"
-)
-
-@spaces.GPU(duration=120)
-# 示例处理函数：生成图像
-def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps, seed):
-    input_images = [img1, img2, img3]
-    # 去除 None
-    input_images = [img for img in input_images if img is not None]
-    if len(input_images) == 0:
+class OmniGenManager:
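+    """
+    Keeps a single cached OmniGenPipeline and rebuilds it only when the
+    requested quantization setting changes.
+    """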
+ """ + with self.lock: + # Only reinitialize if quantization setting changed or pipeline doesn't exist + if self.pipe is None or self.current_quantization != quantization: + # Clear any existing pipeline + if self.pipe is not None: + del self.pipe + self.pipe = None + + # Initialize new pipeline + self.pipe = OmniGenPipeline.from_pretrained( + "Shitao/OmniGen-v1", + Quantization=quantization + ) + self.current_quantization = quantization + + return self.pipe + +# Create a single instance of the manager +pipeline_manager = OmniGenManager() + +@spaces.GPU(duration=180) +def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, quantization): + # Process input images + input_images = [img for img in [img1, img2, img3] if img is not None] + if not input_images: input_images = None + # Get or initialize pipeline with current settings + pipe = pipeline_manager.get_pipeline(quantization) + + # Generate image output = pipe( prompt=text, input_images=input_images, height=height, width=width, guidance_scale=guidance_scale, - img_guidance_scale=1.6, + img_guidance_scale=img_guidance_scale, num_inference_steps=inference_steps, separate_cfg_infer=True, use_kv_cache=False, seed=seed, ) - img = output[0] - return img + + return output[0] + # def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps): # input_images = [] # if img1: @@ -204,6 +235,10 @@ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, infe label="Inference Steps", minimum=1, maximum=100, value=50, step=1 ) + Quantization = gr.Checkbox( + label="Low VRAM (8-bit Quantization)", value=False + ) + seed_input = gr.Slider( label="Seed", minimum=0, maximum=2147483647, value=42, step=1 ) @@ -228,6 +263,7 @@ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, infe guidance_scale_input, num_inference_steps, seed_input, + Quantization, ], outputs=output_image, ) @@ -245,9 +281,10 @@ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, infe guidance_scale_input, num_inference_steps, seed_input, + Quantization, ], outputs=output_image, ) # 启动应用 -demo.launch() \ No newline at end of file +demo.launch()