From 1c9b23ccca1cca72e3f11f8cdad2e88fc7f7e633 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Wed, 27 Nov 2024 12:37:10 +0100
Subject: [PATCH] [PT FE] Improve 16bit patching (#27693)

### Details:
 - *Cherry-pick #27428 and #27413 in 24.6 branch*

### Tickets:
 - *ticket-id*

---------

Signed-off-by: Maxim Vafin
---
 .../openvino/frontend/pytorch/patch_model.py  | 83 +++++++------------
 .../py_frontend_tests/test_torch_frontend.py  |  4 +
 2 files changed, 36 insertions(+), 51 deletions(-)

diff --git a/src/bindings/python/src/openvino/frontend/pytorch/patch_model.py b/src/bindings/python/src/openvino/frontend/pytorch/patch_model.py
index fb8f70e2a566bc..55001180cba3fb 100644
--- a/src/bindings/python/src/openvino/frontend/pytorch/patch_model.py
+++ b/src/bindings/python/src/openvino/frontend/pytorch/patch_model.py
@@ -4,6 +4,7 @@
 # flake8: noqa
 # mypy: ignore-errors
 
+import functools
 import logging
 import torch
 from openvino.frontend.pytorch import ModuleExtension
@@ -11,17 +12,7 @@
 log = logging.getLogger(__name__)
 
 
-class no_jit_trace:
-    def __enter__(self):
-        self.state = torch._C._get_tracing_state()
-        torch._C._set_tracing_state(None)
-
-    def __exit__(self, *args):
-        torch._C._set_tracing_state(self.state)
-        self.state = None
-
-
-def patch_model(model, module_extensions, orig_forward_name, use_meta=False):
+def patch_model(model, module_extensions, orig_forward_name):
     def module_patcher(m, name):
         extension = None
         if m in module_extensions:
@@ -33,43 +24,31 @@ def module_patcher(m, name):
 
         if extension:
             log.debug("Patching module %s", m)
-            # The Trampoline class is instantiated for every module replacement, so we can use class members individually for each module.
+            # The Trampoline class is instantiated for every module replacement, so we can use
+            # class members individually for each module.
             class Trampoline(torch.autograd.Function):
+                # required to be saved in class
                 target_extension = extension
-                original_module = m
-                stashed_args = tuple()
-                stashed_kwargs = {}
 
                 @staticmethod
                 @torch.jit.ignore
-                def forward(*args, **kwargs):
-                    with no_jit_trace():
-                        # `module` is going to be passed to a user-defined function `evaluate`
-                        # `module` is patched: forward function was replaced, and we are actually in this patched function right in this code
-                        # if we pass `module` as-is to the user code below, and it happens to call forward it will lead to infinite recursion or fail
-                        # so we need to temporary patch the module back to the original forward and then return it back again
-                        # stash the current forward to be able to return it back
-                        patched_forward = m.forward
-                        # set original forward for the module
-                        m.forward = getattr(m, orig_forward_name)
-                        # call user code
-                        results = extension.evaluate(m, *Trampoline.stashed_args,
-                                                     **Trampoline.stashed_kwargs)
-                        m.forward = patched_forward  # return patched forward back
-                        return results
+                def forward(ctx, *args, **kwargs):
+                    # Temporarily restore the original forward function of `module` to avoid
+                    # recursion issues in `evaluate`, then revert it back.
+                    patched_forward = m.forward
+                    # set original forward for the module
+                    m.forward = getattr(m, orig_forward_name)
+                    # call user code
+                    results = extension.evaluate(m, *args, **kwargs)
+                    m.forward = patched_forward  # return patched forward back
+                    return results
 
             def new_forward(*args, **kwargs):
-                # use meta device to store args, to save memory
-                if use_meta:
-                    d = torch.device("meta")
-                    Trampoline.stashed_args = tuple(a.to(d) for a in args)
-                    Trampoline.stashed_kwargs = dict((k, v.to(d)) for k, v in kwargs.items())
-                else:
-                    Trampoline.stashed_args = args
-                    Trampoline.stashed_kwargs = kwargs
                 return extension.convert(m, Trampoline.apply, *args, **kwargs)
+
+            # make signature of new_forward same as of forward
+            new_forward = functools.wraps(m.forward)(new_forward)
             setattr(m, orig_forward_name, m.forward)
             m.forward = new_forward
@@ -106,36 +85,38 @@ def __make_16bit_traceable(model: torch.nn.Module):
     extensions = {
         torch.nn.Linear: ModuleExtension(
            torch.nn.Linear, "ov_ext::linear",
-            evaluate=lambda module, *args, **kwargs: torch.full(
-                list(args[0].shape[:-1]) + [module.out_features], 0.5, dtype=torch.float32),
             convert=lambda module, target_op, *args, **kwargs: target_op(args[0],
                                                                          module.weight,
-                                                                         module.bias)),
+                                                                         module.bias),
+            evaluate=lambda module, *args, **kwargs: torch.full(
+                list(args[0].shape[:-1]) + [module.out_features], 0.5, dtype=torch.float32)),
         torch.nn.Embedding: ModuleExtension(
             torch.nn.Embedding, "ov_ext::embedding",
-            evaluate=lambda module, *args, **kwargs: torch.full(
-                list(args[0].shape) + [module.embedding_dim], 0.5, dtype=torch.float32),
             convert=lambda module, target_op, *args, **kwargs: target_op(module.weight,
                                                                          args[0],
                                                                          module.padding_idx,
                                                                          module.scale_grad_by_freq,
-                                                                         module.sparse)),
+                                                                         module.sparse),
+            evaluate=lambda module, *args, **kwargs: torch.full(
+                list(args[1].shape) + [module.embedding_dim], 0.5, dtype=torch.float32)),
     }
     try:
         from transformers.pytorch_utils import Conv1D
 
         extensions[Conv1D] = ModuleExtension(
             Conv1D, "ov_ext::conv1d",
-            evaluate=lambda module, *args, **kwargs: torch.full(
-                list(args[0].shape[:-1]) + [module.nf], 0.5, dtype=torch.float32),
             convert=lambda module, target_op, *args, **kwargs: target_op(args[0],
                                                                          module.weight,
-                                                                         module.bias))
-    except:
+                                                                         module.bias),
+            evaluate=lambda module, *args, **kwargs: torch.full(
+                list(args[0].shape[:-1]) + [module.nf], 0.5, dtype=torch.float32))
+    except ImportError:
         pass
     patch_model(model, extensions,
-                "_openvino_module_extension_patch_orig_forward", use_meta=True)
+                "_openvino_module_extension_patch_orig_forward")
+    dtype_to_patch = [torch.float16, torch.bfloat16]
     for _, module in model.named_modules():
-        if module.__class__ not in extensions and (any(p.dtype in [torch.float16, torch.bfloat16] for p in module.parameters(False))
-                                                   or any(b.dtype in [torch.float16, torch.bfloat16] for b in module.buffers(False))):
+        if (module.__class__ not in extensions and
+            (any(p.dtype in dtype_to_patch for p in module.parameters(False))
+             or any(b.dtype in dtype_to_patch for b in module.buffers(False)))):
             log.debug("Casting module %s to float32", module)
             module.float()
diff --git a/tests/layer_tests/py_frontend_tests/test_torch_frontend.py b/tests/layer_tests/py_frontend_tests/test_torch_frontend.py
index faee72bb5d938a..b659c1735d8453 100644
--- a/tests/layer_tests/py_frontend_tests/test_torch_frontend.py
+++ b/tests/layer_tests/py_frontend_tests/test_torch_frontend.py
@@ -687,6 +687,7 @@ def test_patched_16bit_model_converts():
     from openvino.frontend.pytorch import patch_model
     from openvino import convert_model, compile_model
     import copy
+    import inspect
     from transformers.pytorch_utils import Conv1D
 
     class ModelWithLinear(torch.nn.Module):
@@ -716,6 +717,9 @@ def forward(self, x1, x2):
     model_fp16 = copy.deepcopy(model_ref).half()
     patch_model.__make_16bit_traceable(model_fp16)
 
+    # verify torch.nn.Linear signature after patching
+    signature = inspect.signature(model_fp16.branch1[0].forward).parameters
+    assert ["input"] == list(signature)
     # the approach with patching only works for node with no grad
     with torch.no_grad():
         converted_model = convert_model(model_fp16, example_input=example)
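
### Notes:

The new signature assertion in the test passes because of the `functools.wraps` call added to `patch_model`: wrapping `new_forward` copies the metadata of the original bound `forward` and sets `__wrapped__`, which `inspect.signature` follows. A minimal standalone sketch of that mechanism (illustrative only, not part of the patch; the module and wrapper names are arbitrary):

```python
import functools
import inspect

import torch

linear = torch.nn.Linear(4, 2)

def new_forward(*args, **kwargs):
    # stand-in for extension.convert(m, Trampoline.apply, *args, **kwargs)
    return linear(*args, **kwargs)

# copy name/doc/module from the original forward and set __wrapped__,
# so inspect.signature reports ["input"] instead of (*args, **kwargs)
new_forward = functools.wraps(linear.forward)(new_forward)

assert list(inspect.signature(new_forward).parameters) == ["input"]
```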
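The other behavioral change is inside `Trampoline`: `forward` now takes `ctx` as its first argument, matching the `torch.autograd.Function` contract, and receives the call arguments directly through `apply` instead of reading them from the removed `stashed_args`/`stashed_kwargs` class state, which previously had to be copied to the meta device to save memory. A rough sketch of the resulting data flow, with a dummy computation standing in for `extension.evaluate`:

```python
import torch

class Trampoline(torch.autograd.Function):
    @staticmethod
    @torch.jit.ignore
    def forward(ctx, *args):
        # stand-in for extension.evaluate(m, *args, **kwargs)
        return args[0] * 2

# arguments flow straight through apply; no per-call state is kept on the class
out = Trampoline.apply(torch.ones(3))
```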
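For context, end-to-end usage mirrors the test above: patch a 16-bit model, then convert it without gradients. The model below is illustrative; only the `patch_model.__make_16bit_traceable` and `convert_model` calls come from this patch and its test:

```python
import torch
from openvino import convert_model
from openvino.frontend.pytorch import patch_model

# an illustrative float16 model covered by the Embedding and Linear extensions
model = torch.nn.Sequential(torch.nn.Embedding(100, 8),
                            torch.nn.Linear(8, 3)).half()
patch_model.__make_16bit_traceable(model)

# as the test comments note, the patching approach only works with no grad
with torch.no_grad():
    ov_model = convert_model(model,
                             example_input=torch.randint(0, 100, (1, 5)))
```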