
Commit c4ccae9
wrap in parameters and torch view to correct dtype
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed May 20, 2024
1 parent 4fbc88b commit c4ccae9
Showing 1 changed file with 25 additions and 0 deletions.
@@ -25,6 +25,7 @@
from fms_acceleration import AccelerationPlugin
from peft import LoraConfig, prepare_model_for_kbit_training
from peft.tuners.lora.model import LoraModel
import torch.distributed
from transformers import AutoModelForCausalLM, TrainingArguments
import torch

@@ -121,6 +122,30 @@ def model_loader(self, model_name: str, **kwargs):
            device_map=device_map,
        )

        from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear, QuantLinearFunction
        def forward(self, x):
            out_shape = x.shape[:-1] + (self.outfeatures,)
            quant_linear_fn = QuantLinearFunction

            out = quant_linear_fn.apply(
                x.reshape(-1, x.shape[-1]),
                self.qweight.view(torch.int32),
                self.scales,
                self.qzeros.view(torch.int32),
                self.g_idx,
                self.bits,
                self.maxq,
            )
            out = out.half().reshape(out_shape)
            out = out + self.bias if self.bias is not None else out
            return out

        for mod in model.modules():
            if isinstance(mod, QuantLinear):
                mod.qweight = torch.nn.Parameter(mod.qweight.view(torch_dtype), requires_grad=False)
                mod.qzeros = torch.nn.Parameter(mod.qzeros.view(torch_dtype), requires_grad=False)
                mod.forward = MethodType(forward, mod)

        # replace
        AutoModelForCausalLM.from_config = _old_from_config

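Note (not part of the commit): the patch stores the packed int32 GPTQ buffers viewed as the model dtype, and the patched forward views them back to torch.int32 before the triton kernel runs. This only works because torch.Tensor.view(dtype) reinterprets the underlying bytes rather than casting values, so the round trip is lossless. A minimal sketch of that assumption, with a stand-in buffer:

import torch

# stand-in for a packed int32 quantization buffer (only the bytes matter here)
qweight = torch.arange(16, dtype=torch.int32).reshape(4, 4)

# reinterpret the same memory as float16: no cast, no copy;
# the last dim doubles because float16 is 2 bytes vs int32's 4
as_fp16 = qweight.view(torch.float16)
assert as_fp16.shape == (4, 8)

# viewing back recovers the packed integers bit-for-bit
assert torch.equal(as_fp16.view(torch.int32), qweight)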

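Wrapping the viewed buffers in torch.nn.Parameter(..., requires_grad=False) is also worth a note (again, an explanatory sketch rather than part of the change): a Parameter is registered by the owning module and shows up in .parameters() / .named_parameters(), whereas a plain tensor attribute does not, and requires_grad=False keeps it frozen. A sketch using a made-up Holder module:

import torch

class Holder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        packed = torch.arange(8, dtype=torch.int32)
        self.raw = packed  # plain attribute: not visible to .parameters()
        # registered parameter, stored viewed as a 2-byte float dtype, kept frozen
        self.qweight = torch.nn.Parameter(packed.view(torch.float16), requires_grad=False)

h = Holder()
print([name for name, _ in h.named_parameters()])  # ['qweight']
print(h.qweight.requires_grad)                      # False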

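Finally, the per-instance override mod.forward = MethodType(forward, mod) binds the module-level forward function to each QuantLinear instance, so self inside it refers to that specific module while other instances keep their original forward (MethodType comes from the standard-library types module, which the diff assumes is already imported elsewhere in the file). A tiny illustration with a hypothetical stand-in class; Linearish and patched_forward are made up for the example:

from types import MethodType

class Linearish:
    def __init__(self, scale):
        self.scale = scale

    def forward(self, x):
        return x * self.scale

def patched_forward(self, x):
    # same signature as the original; 'self' is whichever instance it gets bound to
    return x * self.scale + 1

m, other = Linearish(scale=3), Linearish(scale=3)
m.forward = MethodType(patched_forward, m)  # override only this instance
print(m.forward(2), other.forward(2))       # 7 6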