Skip to content

Commit

Permalink
compatibility with diffusers 0.31.0
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Oct 22, 2024
1 parent 5df09e1 commit 52f40af
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 11 deletions.
22 changes: 17 additions & 5 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
)
from optimum.utils.normalized_config import NormalizedConfig, NormalizedTextConfig, NormalizedVisionConfig

from ...intel.utils.import_utils import _transformers_version, is_transformers_version
from ...intel.utils.import_utils import _transformers_version, is_diffusers_version, is_transformers_version
from .model_patcher import (
AquilaModelPatcher,
ArcticModelPatcher,
Expand Down Expand Up @@ -1681,7 +1681,9 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
img_ids_height = self.height // 2
img_ids_width = self.width // 2
return self.random_int_tensor(
[self.batch_size, img_ids_height * img_ids_width, 3],
[self.batch_size, img_ids_height * img_ids_width, 3]
if is_diffusers_version("<", "0.31.0")
else [img_ids_height * img_ids_width, 3],
min_value=0,
max_value=min(img_ids_height, img_ids_width),
framework=framework,
Expand All @@ -1704,7 +1706,11 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
if input_name == "txt_ids":
import torch

shape = [self.batch_size, self.sequence_length, 3]
shape = (
[self.batch_size, self.sequence_length, 3]
if is_diffusers_version("<", "0.31.0")
else [self.sequence_length, 3]
)
dtype = DTYPE_MAPPER.pt(float_dtype)
return torch.full(shape, 0, dtype=dtype)
return super().generate(input_name, framework, int_dtype, float_dtype)
Expand All @@ -1724,8 +1730,14 @@ def inputs(self):
common_inputs = super().inputs
common_inputs.pop("sample", None)
common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"}
common_inputs["txt_ids"] = {0: "batch_size", 1: "sequence_length"}
common_inputs["img_ids"] = {0: "batch_size", 1: "packed_height_width"}
common_inputs["txt_ids"] = (
{0: "batch_size", 1: "sequence_length"} if is_diffusers_version("<", "0.31.0") else {0: "sequence_length"}
)
common_inputs["img_ids"] = (
{0: "batch_size", 1: "packed_height_width"}
if is_diffusers_version("<", "0.31.0")
else {0: "packed_height_width"}
)
if getattr(self._normalized_config, "guidance_embeds", False):
common_inputs["guidance"] = {0: "batch_size"}
return common_inputs
Expand Down
10 changes: 6 additions & 4 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
_openvino_version,
_torch_version,
_transformers_version,
is_diffusers_version,
is_openvino_version,
is_torch_version,
is_transformers_version,
Expand Down Expand Up @@ -2734,10 +2735,11 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
class FluxTransfromerModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
self._model.pos_embed._orig_forward = self._model.pos_embed.forward
self._model.pos_embed.forward = types.MethodType(_embednb_forward, self._model.pos_embed)
if is_diffusers_version("<", "0.31.0"):
self._model.pos_embed._orig_forward = self._model.pos_embed.forward
self._model.pos_embed.forward = types.MethodType(_embednb_forward, self._model.pos_embed)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)

self._model.pos_embed.forward = self._model.pos_embed._orig_forward
if hasattr(self._model.pos_embed, "_orig_forward"):
self._model.pos_embed.forward = self._model.pos_embed._orig_forward
8 changes: 6 additions & 2 deletions optimum/intel/openvino/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,9 +681,13 @@ def _reshape_transformer(
elif inputs.get_any_name() == "pooled_projections":
shapes[inputs] = [batch_size, self.transformer.config["pooled_projection_dim"]]
elif inputs.get_any_name() == "img_ids":
shapes[inputs] = [batch_size, packed_height_width, 3]
shapes[inputs] = (
[batch_size, packed_height_width, 3]
if is_diffusers_version("<", "0.31.0")
else [packed_height_width, 3]
)
elif inputs.get_any_name() == "txt_ids":
shapes[inputs] = [batch_size, -1, 3]
shapes[inputs] = [batch_size, -1, 3] if is_diffusers_version("<", "0.31.0") else [-1, 3]
else:
shapes[inputs][0] = batch_size
shapes[inputs][1] = -1 # text_encoder_3 may have vary input length
Expand Down

0 comments on commit 52f40af

Please sign in to comment.