refactor to apply patch only on FSDP and simplify
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed May 20, 2024
1 parent 33eb88e commit c0e449a
Showing 2 changed files with 61 additions and 24 deletions.
autogptq_utils.py
@@ -18,7 +18,7 @@
# Third Party
from peft import LoraConfig
from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ
-from transformers.utils.import_utils import _is_package_available
+from typing import List, Callable
import torch


@@ -54,3 +54,24 @@ def create_new_module_peft(

    # if module cannot be found, return None which results in a raise in the call-stack
    return new_module
+
+# consider moving this somewhere more general
+def patch_forward_to_view_attributes_before_call(
+    old_forward: Callable,
+    attribute_names: List[str], torch_dtype,
+):
+    # patch old_forward to view the named attributes as torch_dtype
+    # before the call
+
+    def _forward(self, *args, **kwargs):
+        # perform a view on all these attributes
+        for attr_name in attribute_names:
+
+            # the view is a passthrough
+            # if attr.dtype == torch_dtype
+            attr = getattr(self, attr_name)
+            attr = attr.view(torch_dtype)
+            setattr(self, attr_name, attr)
+
+        return old_forward(*args, **kwargs)
+    return _forward
framework_plugin_autogptq.py
@@ -28,6 +28,7 @@
import torch.distributed
from transformers import AutoModelForCausalLM, TrainingArguments
import torch
+import os


class AutoGPTQAccelerationPlugin(AccelerationPlugin):
@@ -51,6 +52,8 @@ def model_loader(self, model_name: str, **kwargs):
        # guarded imports
        # Third Party
        from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
+        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear, QuantLinearFunction
+        from .autogptq_utils import patch_forward_to_view_attributes_before_call

        # Currently we allow only a quantized checkpoint to be loaded; we do not
        # implement the quantization process here.
@@ -122,29 +125,42 @@ def model_loader(self, model_name: str, **kwargs):
            device_map=device_map,
        )

-        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear, QuantLinearFunction
-        def forward(self, x):
-            out_shape = x.shape[:-1] + (self.outfeatures,)
-            quant_linear_fn = QuantLinearFunction
-
-            out = quant_linear_fn.apply(
-                x.reshape(-1, x.shape[-1]),
-                self.qweight.view(torch.int32),
-                self.scales,
-                self.qzeros.view(torch.int32),
-                self.g_idx,
-                self.bits,
-                self.maxq,
-            )
-            out = out.half().reshape(out_shape)
-            out = out + self.bias if self.bias is not None else out
-            return out
-
-        for mod in model.modules():
-            if isinstance(mod, QuantLinear):
-                mod.qweight = torch.nn.Parameter(mod.qweight.view(torch_dtype), requires_grad=False)
-                mod.qzeros = torch.nn.Parameter(mod.qzeros.view(torch_dtype), requires_grad=False)
-                mod.forward = MethodType(forward, mod)
+        # https://github.com/foundation-model-stack/fms-acceleration/pull/15
+        # if running distributed with FSDP, the AutoGPTQ model's tensors need
+        # to be converted to torch.nn.Parameter so that FSDP can shard them,
+        # and the int32 tensors must be stored in a float dtype
+
+        try:
+            world_size = torch.distributed.get_world_size()
+        except ValueError:
+            world_size = 1  # process group not initialized
+
+        if (
+            world_size > 1
+            and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
+        ):
+            # these parameters are to be patched for triton v2
+            # consider making this a map if more kernels need patching
+            PATCH_FOR_FSDP_TRITON_V2 = ['qweight', 'qzeros']
+
+            # patch all the QuantLinear base layers
+            for mod in model.modules():
+                if isinstance(mod, QuantLinear):
+
+                    # convert all patched attributes to Parameters of torch_dtype
+                    # so FSDP can shard them
+                    for attr_name in PATCH_FOR_FSDP_TRITON_V2:
+                        attr = getattr(mod, attr_name)
+                        attr = torch.nn.Parameter(attr.view(torch_dtype), requires_grad=False)
+                        setattr(mod, attr_name, attr)
+
+                    # this patches forward to view these attributes back to the
+                    # original dtype (i.e. int32) before the call into the kernels
+                    _forward = patch_forward_to_view_attributes_before_call(
+                        mod.forward, attribute_names=PATCH_FOR_FSDP_TRITON_V2,
+                        torch_dtype=torch.int32,  # view them back as int32
+                    )
+                    mod.forward = MethodType(_forward, mod)

        # replace
        AutoModelForCausalLM.from_config = _old_from_config
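
The patch above works because `Tensor.view(dtype)` reinterprets the underlying bits without copying, so the packed int32 data survives the round trip through a float dtype; the float values may be meaningless (even NaN), but they are only sharded, never used for arithmetic. A small self-contained check, not part of this commit:

import torch

packed = torch.randint(-2**31, 2**31 - 1, (4, 8), dtype=torch.int32)

# stored as float16 so FSDP can flatten and shard it like any float weight
as_float = packed.view(torch.float16)   # shape (4, 16), same underlying bits

# viewed back to int32 right before the triton kernel is called
restored = as_float.view(torch.int32)   # shape (4, 8) again

assert torch.equal(packed, restored)    # bit-exact round trip, no copy made
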
