From c4ccae92262e62f0a1f52f8c1d3675836f3c3d56 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Mon, 20 May 2024 06:15:52 +0000
Subject: [PATCH] wrap in parameters and torch view to correct dtype

Signed-off-by: Yu Chin Fabian Lim
---
 .../framework_plugin_autogptq.py              | 25 +++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
index 2fd2f1e9..dc92f401 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
@@ -25,6 +25,7 @@ from fms_acceleration import AccelerationPlugin
 from peft import LoraConfig, prepare_model_for_kbit_training
 from peft.tuners.lora.model import LoraModel
+import torch.distributed
 from transformers import AutoModelForCausalLM, TrainingArguments
 import torch
 
@@ -121,6 +122,30 @@ def model_loader(self, model_name: str, **kwargs):
             device_map=device_map,
         )
 
+        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear, QuantLinearFunction
+        def forward(self, x):
+            out_shape = x.shape[:-1] + (self.outfeatures,)
+            quant_linear_fn = QuantLinearFunction
+
+            out = quant_linear_fn.apply(
+                x.reshape(-1, x.shape[-1]),
+                self.qweight.view(torch.int32),
+                self.scales,
+                self.qzeros.view(torch.int32),
+                self.g_idx,
+                self.bits,
+                self.maxq,
+            )
+            out = out.half().reshape(out_shape)
+            out = out + self.bias if self.bias is not None else out
+            return out
+
+        for mod in model.modules():
+            if isinstance(mod, QuantLinear):
+                mod.qweight = torch.nn.Parameter(mod.qweight.view(torch_dtype), requires_grad=False)
+                mod.qzeros = torch.nn.Parameter(mod.qzeros.view(torch_dtype), requires_grad=False)
+                mod.forward = MethodType(forward, mod)
+
         # replace
         AutoModelForCausalLM.from_config = _old_from_config
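
Note (not part of the patch): a minimal standalone sketch of the bit-cast round trip the commit relies on. Tensor.view(dtype) reinterprets raw bits, so an int32 quantized buffer such as qweight or qzeros can be wrapped as a frozen Parameter in the training dtype and viewed back to int32 inside forward without losing information. The tensor shape and the float16 target dtype below are illustrative assumptions, not values taken from the plugin.

# Sketch under assumed shape/dtype: demonstrates that view(dtype) is a lossless bit cast.
import torch

# Stand-in for an int32 quantized weight buffer (shape chosen arbitrarily).
qweight = torch.randint(
    torch.iinfo(torch.int32).min,
    torch.iinfo(torch.int32).max,
    (128, 64),
    dtype=torch.int32,
)

# Reinterpret the same bytes as float16 and register as a non-trainable Parameter,
# mirroring what the patch does with torch_dtype.
wrapped = torch.nn.Parameter(qweight.view(torch.float16), requires_grad=False)

# Inside forward, view back to int32 before handing the buffer to the quantized kernel.
restored = wrapped.view(torch.int32)

assert restored.dtype == torch.int32
assert torch.equal(restored, qweight)  # bit-exact round trip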