[PT FE] Improve 16bit patching (#27693)
### Details:
 - *Cherry-pick #27428 and #27413 into the 24.6 branch*

### Tickets:
 - *ticket-id*

---------

Signed-off-by: Maxim Vafin <[email protected]>
mvafin authored Nov 27, 2024
1 parent 45779c2 commit 1c9b23c
Showing 2 changed files with 36 additions and 51 deletions.
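
For orientation, here is a minimal usage sketch of the 16-bit patching flow this commit touches, distilled from the test updated below; the toy model and example input are placeholders and not part of the change:

```python
import torch
from openvino import convert_model
from openvino.frontend.pytorch import patch_model

# Toy fp16 model standing in for a real network (hypothetical example).
model_fp16 = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).half()
example = (torch.randn(2, 8),)

# Replace Linear/Embedding/Conv1D forwards with traceable trampolines and
# cast any remaining 16-bit modules to float32.
patch_model.__make_16bit_traceable(model_fp16)

# The patching approach only works with autograd disabled.
with torch.no_grad():
    ov_model = convert_model(model_fp16, example_input=example)
```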
83 changes: 32 additions & 51 deletions src/bindings/python/src/openvino/frontend/pytorch/patch_model.py
@@ -4,24 +4,15 @@
 # flake8: noqa
 # mypy: ignore-errors

+import functools
 import logging
 import torch
 from openvino.frontend.pytorch import ModuleExtension

 log = logging.getLogger(__name__)


-class no_jit_trace:
-    def __enter__(self):
-        self.state = torch._C._get_tracing_state()
-        torch._C._set_tracing_state(None)
-
-    def __exit__(self, *args):
-        torch._C._set_tracing_state(self.state)
-        self.state = None
-
-
-def patch_model(model, module_extensions, orig_forward_name, use_meta=False):
+def patch_model(model, module_extensions, orig_forward_name):
     def module_patcher(m, name):
         extension = None
         if m in module_extensions:
@@ -33,43 +24,31 @@ def module_patcher(m, name):

         if extension:
             log.debug("Patching module %s", m)
-            # The Trampoline class is instantiated for every module replacement, so we can use class members individually for each module.
+            # The Trampoline class is instantiated for every module replacement, so we can use
+            # class members individually for each module.

             class Trampoline(torch.autograd.Function):
                 # required to be saved in class
                 target_extension = extension
                 original_module = m
-                stashed_args = tuple()
-                stashed_kwargs = {}

                 @staticmethod
                 @torch.jit.ignore
-                def forward(*args, **kwargs):
-                    with no_jit_trace():
-                        # `module` is going to be passed to a user-defined function `evaluate`
-                        # `module` is patched: forward function was replaced, and we are actually in this patched function right in this code
-                        # if we pass `module` as-is to the user code below, and it happens to call forward it will lead to infinite recursion or fail
-                        # so we need to temporary patch the module back to the original forward and then return it back again
-                        # stash the current forward to be able to return it back
-                        patched_forward = m.forward
-                        # set original forward for the module
-                        m.forward = getattr(m, orig_forward_name)
-                        # call user code
-                        results = extension.evaluate(m, *Trampoline.stashed_args,
-                                                     **Trampoline.stashed_kwargs)
-                        m.forward = patched_forward  # return patched forward back
-                        return results
+                def forward(ctx, *args, **kwargs):
+                    # Temporarily restore the original forward function of `module` to avoid
+                    # recursion issues in `evaluate`, then revert it back.
+                    patched_forward = m.forward
+                    # set original forward for the module
+                    m.forward = getattr(m, orig_forward_name)
+                    # call user code
+                    results = extension.evaluate(m, *args, **kwargs)
+                    m.forward = patched_forward  # return patched forward back
+                    return results

             def new_forward(*args, **kwargs):
-                # use meta device to store args, to save memory
-                if use_meta:
-                    d = torch.device("meta")
-                    Trampoline.stashed_args = tuple(a.to(d) for a in args)
-                    Trampoline.stashed_kwargs = dict((k, v.to(d)) for k, v in kwargs.items())
-                else:
-                    Trampoline.stashed_args = args
-                    Trampoline.stashed_kwargs = kwargs
                 return extension.convert(m, Trampoline.apply, *args, **kwargs)

+            # make signature of new_forward same as of forward
+            new_forward = functools.wraps(m.forward)(new_forward)
             setattr(m, orig_forward_name, m.forward)
             m.forward = new_forward
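
In the new `forward` above, arguments arrive directly through `Trampoline.apply` instead of being stashed on the class; the only remaining trick is the temporary forward swap. A condensed illustrative helper (not part of the patch; the real code keeps the swap inline, and the `try`/`finally` here only makes the restore explicit):

```python
def call_with_original_forward(module, orig_forward_name, evaluate, *args, **kwargs):
    # Run user-supplied `evaluate` with the module's original forward restored,
    # so that evaluate may call module.forward without re-entering the wrapper.
    patched_forward = module.forward
    module.forward = getattr(module, orig_forward_name)
    try:
        return evaluate(module, *args, **kwargs)
    finally:
        module.forward = patched_forward
```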

@@ -106,36 +85,38 @@ def __make_16bit_traceable(model: torch.nn.Module):
     extensions = {
         torch.nn.Linear: ModuleExtension(
             torch.nn.Linear, "ov_ext::linear",
-            evaluate=lambda module, *args, **kwargs: torch.full(
-                list(args[0].shape[:-1]) + [module.out_features], 0.5, dtype=torch.float32),
             convert=lambda module, target_op, *args, **kwargs: target_op(args[0],
                                                                          module.weight,
-                                                                         module.bias)),
+                                                                         module.bias),
+            evaluate=lambda module, *args, **kwargs: torch.full(
+                list(args[0].shape[:-1]) + [module.out_features], 0.5, dtype=torch.float32)),
         torch.nn.Embedding: ModuleExtension(
             torch.nn.Embedding, "ov_ext::embedding",
-            evaluate=lambda module, *args, **kwargs: torch.full(
-                list(args[0].shape) + [module.embedding_dim], 0.5, dtype=torch.float32),
             convert=lambda module, target_op, *args, **kwargs: target_op(module.weight,
                                                                          args[0],
                                                                          module.padding_idx,
                                                                          module.scale_grad_by_freq,
-                                                                         module.sparse)),
+                                                                         module.sparse),
+            evaluate=lambda module, *args, **kwargs: torch.full(
+                list(args[1].shape) + [module.embedding_dim], 0.5, dtype=torch.float32)),
     }
     try:
         from transformers.pytorch_utils import Conv1D
         extensions[Conv1D] = ModuleExtension(
             Conv1D, "ov_ext::conv1d",
-            evaluate=lambda module, *args, **kwargs: torch.full(
-                list(args[0].shape[:-1]) + [module.nf], 0.5, dtype=torch.float32),
             convert=lambda module, target_op, *args, **kwargs: target_op(args[0],
                                                                          module.weight,
-                                                                         module.bias))
-    except:
+                                                                         module.bias),
+            evaluate=lambda module, *args, **kwargs: torch.full(
+                list(args[0].shape[:-1]) + [module.nf], 0.5, dtype=torch.float32))
+    except ImportError:
         pass
     patch_model(model, extensions,
-                "_openvino_module_extension_patch_orig_forward", use_meta=True)
+                "_openvino_module_extension_patch_orig_forward")
+    dtype_to_patch = [torch.float16, torch.bfloat16]
     for _, module in model.named_modules():
-        if module.__class__ not in extensions and (any(p.dtype in [torch.float16, torch.bfloat16] for p in module.parameters(False))
-                                                   or any(b.dtype in [torch.float16, torch.bfloat16] for b in module.buffers(False))):
+        if (module.__class__ not in extensions and
+                (any(p.dtype in dtype_to_patch for p in module.parameters(False))
+                 or any(b.dtype in dtype_to_patch for b in module.buffers(False)))):
             log.debug("Casting module %s to float32", module)
             module.float()
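
For context, each entry above follows the `ModuleExtension` contract: `convert` maps the original module call onto the target op (it receives the module, the trampoline as `target_op`, and the call arguments), while `evaluate` only has to return a stand-in tensor of the right shape during tracing. Because `evaluate` now receives whatever `convert` forwarded to the target op, the Embedding stub indexes `args[1]` (the indices) rather than `args[0]` (the weight). A hedged sketch of a custom extension under those assumptions — the module and the `ov_ext::scale` op name are hypothetical, and a matching converter would still be needed on the frontend side:

```python
import torch
from openvino.frontend.pytorch import ModuleExtension


class Scale(torch.nn.Module):
    """Hypothetical 16-bit module used only to illustrate the contract."""

    def __init__(self, dim):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(dim, dtype=torch.float16))

    def forward(self, x):
        return x * self.weight


scale_ext = ModuleExtension(
    Scale, "ov_ext::scale",  # hypothetical target op name
    # Pass the activation and the raw 16-bit weight straight to the target op.
    convert=lambda module, target_op, *args, **kwargs: target_op(args[0], module.weight),
    # args here are the values forwarded by `convert`, so args[0] is the activation;
    # return a dummy tensor matching the output shape.
    evaluate=lambda module, *args, **kwargs: torch.full(args[0].shape, 0.5, dtype=torch.float32))
```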
4 changes: 4 additions & 0 deletions tests/layer_tests/py_frontend_tests/test_torch_frontend.py
@@ -687,6 +687,7 @@ def test_patched_16bit_model_converts():
     from openvino.frontend.pytorch import patch_model
     from openvino import convert_model, compile_model
     import copy
+    import inspect
     from transformers.pytorch_utils import Conv1D

     class ModelWithLinear(torch.nn.Module):
@@ -716,6 +717,9 @@ def forward(self, x1, x2):
     model_fp16 = copy.deepcopy(model_ref).half()

     patch_model.__make_16bit_traceable(model_fp16)
+    # verify torch.nn.Linear signature after patching
+    signature = inspect.signature(model_ref.branch1[0].forward).parameters
+    assert ["input"] == list(signature)
     # the approach with patching only works for node with no grad
     with torch.no_grad():
         converted_model = convert_model(model_fp16, example_input=example)
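
For reference, `functools.wraps` (used in the `patch_model.py` change above to keep the patched `forward`'s signature) copies the wrapped callable's metadata, including `__wrapped__`, so `inspect.signature` reports the original parameters. A standalone illustration of that behavior with hypothetical names:

```python
import functools
import inspect


def forward(input):
    # stands in for torch.nn.Linear.forward
    return input


def new_forward(*args, **kwargs):
    return forward(*args, **kwargs)


# Copy the original callable's metadata onto the wrapper.
new_forward = functools.wraps(forward)(new_forward)

# inspect.signature follows __wrapped__, so the original parameter list survives.
assert list(inspect.signature(new_forward).parameters) == ["input"]
```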
