Stuff to save VRAM to less than 7 GB #29

Open · wants to merge 8 commits into main
74 changes: 66 additions & 8 deletions OmniGen/model.py
@@ -5,6 +5,7 @@
import numpy as np
import math
from typing import Dict
import torch.nn.functional as F

from diffusers.loaders import PeftAdapterMixin
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
@@ -16,7 +17,7 @@

def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class TimestepEmbedder(nn.Module):
"""
@@ -149,19 +150,54 @@ def forward(self, x):
return x


class Int8Quantized(nn.Module):
def __init__(self, tensor, scale_factor=None):
super().__init__()
if scale_factor is None:
max_val = torch.max(torch.abs(tensor))
scale_factor = max_val / 127.0
# Store quantized weights and scale factor
self.register_buffer('quantized_weight', torch.round(tensor / scale_factor).to(torch.int8))
self.register_buffer('scale_factor', torch.tensor(scale_factor))

def forward(self, dtype=None):
# Dequantize and convert to specified dtype
weight = self.quantized_weight.float() * self.scale_factor
if dtype is not None:
weight = weight.to(dtype)
return weight



class QuantizedLinear(nn.Module):
def __init__(self, weight, bias=None):
super().__init__()
self.weight_quantized = Int8Quantized(weight)
if bias is not None:
self.register_buffer('bias', bias)
else:
self.bias = None

def forward(self, x):
# Dequantize weight to match input dtype
weight = self.weight_quantized(dtype=x.dtype)
return F.linear(x, weight, self.bias)
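
For intuition, here is a minimal standalone round-trip sketch (not part of the diff) of the symmetric int8 scheme used by Int8Quantized above: one scale factor of max|W| / 127 is stored next to the int8 weights, and dequantization simply multiplies it back in. The tensor shape and the error check are illustrative only.

import torch

w = torch.randn(64, 64)
q = Int8Quantized(w)                  # int8 storage plus a single scale factor
w_restored = q(dtype=torch.float32)   # dequantize on demand
print((w - w_restored).abs().max())   # rounding error is at most about scale / 2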


class OmniGen(nn.Module, PeftAdapterMixin):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
transformer_config: Phi3Config,
transformer_config=Phi3Config,
patch_size=2,
in_channels=4,
pe_interpolation: float = 1.0,
pos_embed_max_size: int = 192,
):
super().__init__()

self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
@@ -174,7 +210,7 @@ def __init__(

self.time_token = TimestepEmbedder(hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)

self.pe_interpolation = pe_interpolation
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
@@ -185,24 +221,46 @@ def __init__(

self.llm = Phi3Transformer(config=transformer_config)
self.llm.config.use_cache = False


def _quantize_module(self, module):
"""
Quantize a module to 8-bit precision
"""
for name, child in module.named_children():
if isinstance(child, nn.Linear):
setattr(module, name, QuantizedLinear(child.weight.data, child.bias.data if child.bias is not None else None))
elif isinstance(child, nn.LayerNorm):
# Skip quantization for LayerNorm
continue
else:
self._quantize_module(child)

@classmethod
def from_pretrained(cls, model_name):
def from_pretrained(cls, model_name, quantize=False): # Add quantize parameter
if not os.path.exists(model_name):
cache_folder = os.getenv('HF_HUB_CACHE')
model_name = snapshot_download(repo_id=model_name,
cache_dir=cache_folder,
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
cache_dir=cache_folder,
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])

config = Phi3Config.from_pretrained(model_name)
model = cls(config)

if os.path.exists(os.path.join(model_name, 'model.safetensors')):
print("Loading safetensors")
ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
else:
ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')

# Load weights first
model.load_state_dict(ckpt)
return model

# Only quantize if explicitly requested
if quantize:
print("Quantizing weights to 8-bit...")
model._quantize_module(model.llm)

return model
def initialize_weights(self):
assert not hasattr(self, "llama")

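Taken together, the model.py changes make 8-bit loading opt-in. A hedged usage sketch (the repo id matches the one used in pipeline.py; everything else is illustrative):

from OmniGen import OmniGen

# quantize defaults to False, so existing callers keep full-precision weights;
# passing True converts the transformer's nn.Linear layers to QuantizedLinear.
model = OmniGen.from_pretrained("Shitao/OmniGen-v1", quantize=True)
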
127 changes: 92 additions & 35 deletions OmniGen/pipeline.py
@@ -20,8 +20,9 @@

from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler

import gc # For clearing unused objects

logger = logging.get_logger(__name__)

EXAMPLE_DOC_STRING = """
Examples:
@@ -40,14 +41,13 @@
```
"""



class OmniGenPipeline:
def __init__(
self,
vae: AutoencoderKL,
model: OmniGen,
processor: OmniGenProcessor,

):
self.vae = vae
self.model = model
@@ -59,34 +59,38 @@ def __init__(
self.vae.to(self.device)

@classmethod
def from_pretrained(cls, model_name, vae_path: str=None):
def from_pretrained(cls, model_name, vae_path: str=None, Quantization: bool=False):
if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"):
logger.info("Model not found, downloading...")
cache_folder = os.getenv('HF_HUB_CACHE')
model_name = snapshot_download(repo_id=model_name,
cache_dir=cache_folder,
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt'])
cache_dir=cache_folder,
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt'])
logger.info(f"Downloaded model to {model_name}")
model = OmniGen.from_pretrained(model_name)

# Pass Quantization parameter to OmniGen's from_pretrained
model = OmniGen.from_pretrained(model_name, quantize=Quantization)

processor = OmniGenProcessor.from_pretrained(model_name)

if os.path.exists(os.path.join(model_name, "vae")):
vae = AutoencoderKL.from_pretrained(os.path.join(model_name, "vae"))
elif vae_path is not None:
vae = AutoencoderKL.from_pretrained(vae_path).to(device)
vae = AutoencoderKL.from_pretrained(vae_path)
else:
logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF")
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")

return cls(vae, model, processor)




def merge_lora(self, lora_path: str):
model = PeftModel.from_pretrained(self.model, lora_path)
model.merge_and_unload()


self.model = model

def to(self, device: Union[str, torch.device]):
if isinstance(device, str):
device = torch.device(device)
@@ -101,7 +105,7 @@ def vae_encode(self, x, dtype):
x = self.vae.encode(x).latent_dist.sample().mul_(self.vae.config.scaling_factor)
x = x.to(dtype)
return x

def move_to_device(self, data):
if isinstance(data, list):
return [x.to(self.device) for x in data]
@@ -124,13 +128,15 @@ def __call__(
use_kv_cache: bool = True,
dtype: torch.dtype = torch.bfloat16,
seed: int = None,
Quantization: bool = False,
):

r"""
Function invoked when calling the pipeline for generation.

Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
input_images (`List[str]` or `List[List[str]]`, *optional*):
The list of input images. We will replace the "<|image_i|>" in the prompt with the i-th image in the list.
height (`int`, *optional*, defaults to 1024):
@@ -146,9 +152,9 @@ def __call__(
1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
use_img_guidance (`bool`, *optional*, defaults to True):
Defined as equation 3 in [InstructPix2Pix](https://arxiv.org/pdf/2211.09800).
img_guidance_scale (`float`, *optional*, defaults to 1.6):
Defined as equation 3 in [InstructPix2Pix](https://arxiv.org/pdf/2211.09800).
separate_cfg_infer (`bool`, *optional*, defaults to False):
Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
@@ -160,6 +166,8 @@
Returns:
A list with the generated images.
"""


assert height%16 == 0 and width%16 == 0
if separate_cfg_infer:
use_kv_cache = False
@@ -170,63 +178,112 @@
prompt = [prompt]
input_images = [input_images] if input_images is not None else None


input_data = self.processor(prompt, input_images, height=height, width=width, use_img_cfg=use_img_guidance, separate_cfg_input=separate_cfg_infer)

num_prompt = len(prompt)
num_cfg = 2 if use_img_guidance else 1
latent_size_h, latent_size_w = height//8, width//8
latent_size_h, latent_size_w = height // 8, width // 8

if seed is not None:
generator = torch.Generator(device=self.device).manual_seed(seed)
else:
generator = None
latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator)
latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
latents = torch.cat([latents] * (1 + num_cfg), 0).to(dtype)


# Load VAE into VRAM (GPU) in bfloat16
self.vae.to(self.device, dtype=torch.bfloat16)




input_img_latents = []
if separate_cfg_infer:
for temp_pixel_values in input_data['input_pixel_values']:
temp_input_latents = []
for img in temp_pixel_values:
img = self.vae_encode(img.to(self.device), dtype)
img = self.vae_encode(img.to(self.device, dtype=torch.bfloat16), dtype)

temp_input_latents.append(img)
input_img_latents.append(temp_input_latents)
else:
for img in input_data['input_pixel_values']:
img = self.vae_encode(img.to(self.device), dtype)
img = self.vae_encode(img.to(self.device, dtype=torch.bfloat16), dtype)

input_img_latents.append(img)

model_kwargs = dict(input_ids=self.move_to_device(input_data['input_ids']),
input_img_latents=input_img_latents,
input_image_sizes=input_data['input_image_sizes'],
attention_mask=self.move_to_device(input_data["attention_mask"]),
position_ids=self.move_to_device(input_data["position_ids"]),


model_kwargs = dict(input_ids=self.move_to_device(input_data['input_ids']),
input_img_latents=input_img_latents,
input_image_sizes=input_data['input_image_sizes'],
attention_mask=self.move_to_device(input_data["attention_mask"]),
position_ids=self.move_to_device(input_data["position_ids"]),
cfg_scale=guidance_scale,
img_cfg_scale=img_guidance_scale,
use_img_cfg=use_img_guidance,
use_kv_cache=use_kv_cache)



# unload VAE to CPU
self.vae.to('cpu')
torch.cuda.empty_cache() # Clear VRAM
gc.collect() # Run garbage collection to free system RAM



if separate_cfg_infer:
func = self.model.forward_with_separate_cfg
else:
func = self.model.forward_with_cfg
self.model.to(dtype)


# move main model to GPU
self.model.to(self.device, dtype=dtype)


scheduler = OmniGenScheduler(num_steps=num_inference_steps)
samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache)
samples = samples.chunk((1+num_cfg), dim=0)[0]
samples = samples.chunk((1 + num_cfg), dim=0)[0]

samples = samples.to(torch.float32)
if self.vae.config.shift_factor is not None:
samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
else:
samples = samples / self.vae.config.scaling_factor
samples = samples / self.vae.config.scaling_factor

# unload main model to CPU
self.model.to('cpu')
torch.cuda.empty_cache() # Clear VRAM
gc.collect() # Run garbage collection to free system RAM

# Move samples to GPU and ensure they are in bfloat16 (for the VAE)
samples = samples.to(self.device, dtype=torch.bfloat16)

# Load VAE into VRAM (GPU) in bfloat16
self.vae.to(self.device, dtype=torch.bfloat16)

# Decode the samples using the VAE
samples = self.vae.decode(samples).sample

output_samples = (samples * 0.5 + 0.5).clamp(0, 1)*255

# unload VAE to CPU
self.vae.to('cpu')
torch.cuda.empty_cache() # Clear VRAM
gc.collect() # Run garbage collection to free system RAM


# Convert samples back to float32 for further processing
samples = samples.to(torch.float32)


# Convert samples to uint8 for final image output
output_samples = (samples * 0.5 + 0.5).clamp(0, 1) * 255
output_samples = output_samples.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
# Create output images
output_images = []
for i, sample in enumerate(output_samples):
output_images.append(Image.fromarray(sample))

return output_images

# Return the generated images
return output_images
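
For completeness, a hedged end-to-end sketch of the low-VRAM path this PR adds: int8 weights via Quantization=True plus the VAE/model CPU offloading shown above. OmniGenPipeline is assumed to be exported by the OmniGen package, and the prompt and file name are placeholders.

from OmniGen import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1", Quantization=True)
images = pipe(
    prompt="a photo of a cat wearing a red scarf",   # placeholder prompt
    height=1024,
    width=1024,
    guidance_scale=2.5,
    separate_cfg_infer=True,   # lowers peak memory at the cost of speed
    seed=0,
)
images[0].save("output.png")   # placeholder file name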