diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py
index 49072208..b02c56e5 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py
@@ -18,7 +18,7 @@
 # Third Party
 from peft import LoraConfig
 from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ
-from transformers.utils.import_utils import _is_package_available
+from typing import List, Callable
 
 import torch
 
@@ -54,3 +54,24 @@ def create_new_module_peft(
 
     # if module cannot be found, return None which results in a raise in the call-stack
     return new_module
+
+# consider moving this somewhere more general
+def patch_forward_to_view_attributes_before_call(
+    old_forward: Callable,
+    attribute_names: List[str], torch_dtype,
+):
+    # patch old_forward to view the named attributes as torch_dtype
+    # before the call
+
+    def _forward(self, *args, **kwargs):
+        # perform a view on all these attributes
+        for attr_name in attribute_names:
+
+            # the view should be a passthrough
+            # if attr.dtype == torch_dtype
+            attr = getattr(self, attr_name)
+            attr = attr.view(torch_dtype)
+            setattr(self, attr_name, attr)
+
+        return old_forward(*args, **kwargs)
+    return _forward
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
index dc92f401..fa6082ab 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
@@ -28,6 +28,7 @@
 import torch.distributed
 from transformers import AutoModelForCausalLM, TrainingArguments
 import torch
+import os
 
 
 class AutoGPTQAccelerationPlugin(AccelerationPlugin):
@@ -51,6 +52,8 @@ def model_loader(self, model_name: str, **kwargs):
         # guarded imports
         # Third Party
         from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
+        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
+        from .autogptq_utils import patch_forward_to_view_attributes_before_call
 
         # Currently we allow only a quantized checkpoint to be loaded, we do not
         # implement the quantization process here.
@@ -122,29 +125,42 @@ def model_loader(self, model_name: str, **kwargs):
             device_map=device_map,
         )
 
-        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear, QuantLinearFunction
-        def forward(self, x):
-            out_shape = x.shape[:-1] + (self.outfeatures,)
-            quant_linear_fn = QuantLinearFunction
-
-            out = quant_linear_fn.apply(
-                x.reshape(-1, x.shape[-1]),
-                self.qweight.view(torch.int32),
-                self.scales,
-                self.qzeros.view(torch.int32),
-                self.g_idx,
-                self.bits,
-                self.maxq,
-            )
-            out = out.half().reshape(out_shape)
-            out = out + self.bias if self.bias is not None else out
-            return out
-
-        for mod in model.modules():
-            if isinstance(mod, QuantLinear):
-                mod.qweight = torch.nn.Parameter(mod.qweight.view(torch_dtype), requires_grad=False)
-                mod.qzeros = torch.nn.Parameter(mod.qzeros.view(torch_dtype), requires_grad=False)
-                mod.forward = MethodType(forward, mod)
+        # https://github.com/foundation-model-stack/fms-acceleration/pull/15
+        # if FSDP distributed, need to convert the AutoGPTQ model's
+        # int32 tensors to torch.nn.Parameters so FSDP can shard them.
+        # Also need to store the int32 tensors viewed as a float type
+
+        try:
+            world_size = torch.distributed.get_world_size()
+        except ValueError:
+            world_size = 1  # process group not initialized
+
+        if (
+            world_size > 1
+            and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
+        ):
+            # these parameters are to be patched for triton v2
+            # consider making a map if patching more kernels
+            PATCH_FOR_FSDP_TRITON_V2 = ['qweight', 'qzeros']
+
+            # patch all the QuantLinear base layers
+            for mod in model.modules():
+                if isinstance(mod, QuantLinear):
+
+                    # convert all patched attributes to Parameters of torch_dtype
+                    # so FSDP can shard them
+                    for attr_name in PATCH_FOR_FSDP_TRITON_V2:
+                        attr = getattr(mod, attr_name)
+                        attr = torch.nn.Parameter(attr.view(torch_dtype), requires_grad=False)
+                        setattr(mod, attr_name, attr)
+
+                    # this patches the forward to view them back to the original
+                    # type (i.e. int32) before the call into the kernels
+                    _forward = patch_forward_to_view_attributes_before_call(
+                        mod.forward, attribute_names=PATCH_FOR_FSDP_TRITON_V2,
+                        torch_dtype=torch.int32,  # patch it back to int32
+                    )
+                    mod.forward = MethodType(_forward, mod)
 
         # replace
        AutoModelForCausalLM.from_config = _old_from_config
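
A note on why the `view(torch_dtype)` trick above is safe: `Tensor.view(dtype)` reinterprets the underlying storage bitwise, with no conversion and no copy, so parking the int32 packed weights in a float dtype for FSDP loses no quantization data. A minimal sketch (tensor values are illustrative):

```python
import torch

q = torch.arange(8, dtype=torch.int32).reshape(2, 4)

# bitwise reinterpretation, no copy: each 4-byte int32 is split into
# two 2-byte halves, so the last dimension doubles
as_fp16 = q.view(torch.float16)
assert as_fp16.shape == (2, 8)

# viewing back restores the original dtype, shape, and bits exactly
assert torch.equal(as_fp16.view(torch.int32), q)
```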
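
And a toy, FSDP-free sketch of how `patch_forward_to_view_attributes_before_call` composes with the attribute conversion. `ToyQuantLinear` is a hypothetical stand-in for AutoGPTQ's triton-v2 `QuantLinear` (which additionally wraps the attributes in `nn.Parameter` so FSDP can shard them), and the import path assumes the plugin package is importable as `fms_acceleration_peft`:

```python
from types import MethodType

import torch

from fms_acceleration_peft.autogptq_utils import (
    patch_forward_to_view_attributes_before_call,
)

class ToyQuantLinear(torch.nn.Module):
    # hypothetical stand-in for the real QuantLinear
    def __init__(self):
        super().__init__()
        # int32 packed weights, as stored in a GPTQ checkpoint
        self.qweight = torch.arange(8, dtype=torch.int32).reshape(2, 4)

    def forward(self, x):
        # a real kernel would unpack self.qweight here; it expects int32
        assert self.qweight.dtype == torch.int32
        return x

mod = ToyQuantLinear()

# reinterpret the int32 storage as float16, as the plugin does for FSDP
mod.qweight = mod.qweight.view(torch.float16)

# patch forward so the attribute is viewed back to int32 before each call;
# on later calls the view is a passthrough since the dtype already matches
_forward = patch_forward_to_view_attributes_before_call(
    mod.forward, attribute_names=["qweight"], torch_dtype=torch.int32
)
mod.forward = MethodType(_forward, mod)

mod(torch.zeros(1))  # the assert inside forward passes: qweight is int32 again
```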