Commit 659b177: fix qwen quantization

Signed-off-by: ssbuild <[email protected]>
Committed by ssbuild on Aug 4, 2023 · 1 parent e581af3

Showing 7 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/deep_training/nlp/models/baichuan/modeling_baichuan.py
@@ -584,7 +584,7 @@ def __init__(self, config: BaiChuanConfig,**kwargs):
         self.post_init()

         self.quantized = False
-        if self.config.quantization_bit is not None and self.config.quantization_bit not in [0, 32]:
+        if self.config.quantization_bit in [4,8]:
             self.quantize(self.config.quantization_bit, empty_init=True)

     def get_input_embeddings(self):
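The same guard change recurs in every model file in this commit. Previously the baichuan and qwen models quantized for any value other than None, 0, or 32, and the two ChatGLM variants quantized on any truthy value; both predicates let an unsupported width such as 16 slip through to quantize(). The new whitelist only admits the 4- and 8-bit paths. A minimal sketch of the two predicates side by side (function names are illustrative, not from the repository):

def old_guard(quantization_bit):
    # Old predicate: anything except None, 0, or 32 triggered quantize(),
    # so an unsupported width such as 16 was still passed through.
    return quantization_bit is not None and quantization_bit not in [0, 32]

def new_guard(quantization_bit):
    # New predicate: whitelist only the bit widths the kernels support.
    return quantization_bit in [4, 8]

for bit in (None, 0, 4, 8, 16, 32):
    print(bit, old_guard(bit), new_guard(bit))  # the two disagree only at 16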
@@ -379,7 +379,7 @@ def __init__(self, config,**kwargs):
         self.post_init()

         self.quantized = False
-        if self.config.quantization_bit is not None and self.config.quantization_bit not in [0,32]:
+        if self.config.quantization_bit in [4,8]:
             self.quantize(self.config.quantization_bit,empty_init=True)

2 changes: 1 addition & 1 deletion src/deep_training/nlp/models/chatglm/__init__.py
@@ -1071,7 +1071,7 @@ def __init__(self, config: ChatGLMConfig):
         self.config = config
         self.quantized = False

-        if self.config.quantization_bit:
+        if self.config.quantization_bit in [4,8]:
             self.quantize(self.config.quantization_bit, empty_init=True,dtype=self.transformer.params_dtype or torch.half)

     def get_output_embeddings(self):
2 changes: 1 addition & 1 deletion src/deep_training/nlp/models/chatglm2/modeling_chatglm.py
@@ -876,7 +876,7 @@ def __init__(self, config: ChatGLMConfig,device=None):
         self.config = config

         self.quantized = False
-        if self.config.quantization_bit:
+        if self.config.quantization_bit in [4,8]:
             self.quantize(self.config.quantization_bit, empty_init=True)

     def _update_model_kwargs_for_generation(
4 changes: 2 additions & 2 deletions src/deep_training/nlp/models/internlm/quantization.py
@@ -161,7 +161,7 @@ def quantize(model, bits, empty_init=False, device=None,**kwarg):
            QuantizedLinear(
                bits=bits,
                weight=w.weight.to(torch.cuda.current_device()),
-               bias=None,
+               bias=w.bias.to(torch.cuda.current_device()),
                empty_init=empty_init,
                device=w.weight.device if device is None else device,
                dtype=w.weight.dtype,
@@ -176,7 +176,7 @@ def quantize(model, bits, empty_init=False, device=None,**kwarg):
            QuantizedLinear(
                bits=bits,
                weight=w.weight.to(torch.cuda.current_device()),
-               bias=None,
+               bias=w.bias.to(torch.cuda.current_device()),
                empty_init=empty_init,
                device=w.weight.device if device is None else device,
                dtype=w.weight.dtype,
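This hunk (repeated in qwen/quantization.py below) is the substantive bug fix: when quantize() swaps an nn.Linear for a QuantizedLinear, the replacement previously hard-coded bias=None, silently discarding the bias of any layer created with bias=True. A minimal sketch of that replacement pattern; replace_linears and the stand-in quantized_cls are illustrative, and unlike the commit (which forwards w.bias unconditionally) this version also guards the bias-free case:

import torch.nn as nn

def replace_linears(module, bits, quantized_cls):
    # Recursively swap every nn.Linear for a quantized equivalent.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, quantized_cls(
                bits=bits,
                weight=child.weight.detach(),
                # The fix: carry the existing bias over instead of passing
                # bias=None, which dropped it for layers built with bias=True.
                bias=None if child.bias is None else child.bias.detach(),
            ))
        else:
            replace_linears(child, bits, quantized_cls)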
2 changes: 1 addition & 1 deletion src/deep_training/nlp/models/qwen/modeling_qwen.py
@@ -824,7 +824,7 @@ def __init__(self, config,**kwargs):
         self.post_init()

         self.quantized = False
-        if self.config.quantization_bit is not None and self.config.quantization_bit not in [0, 32]:
+        if self.config.quantization_bit in [4,8]:
             self.quantize(self.config.quantization_bit, empty_init=True)

     def get_output_embeddings(self):
4 changes: 2 additions & 2 deletions src/deep_training/nlp/models/qwen/quantization.py
@@ -160,7 +160,7 @@ def quantize(model, bits, empty_init=False, device=None,**kwarg):
            QuantizedLinear(
                bits=bits,
                weight=w.weight.to(torch.cuda.current_device()),
-               bias=None,
+               bias=w.bias.to(torch.cuda.current_device()),
                empty_init=empty_init,
                device=w.weight.device if device is None else device,
                dtype=w.weight.dtype,
@@ -175,7 +175,7 @@ def quantize(model, bits, empty_init=False, device=None,**kwarg):
            QuantizedLinear(
                bits=bits,
                weight=w.weight.to(torch.cuda.current_device()),
-               bias=None,
+               bias=w.bias.to(torch.cuda.current_device()),
                empty_init=empty_init,
                device=w.weight.device if device is None else device,
                dtype=w.weight.dtype,
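A quick standalone check of what the old code path cost: rebuilding a biased linear layer while discarding its bias changes the layer's outputs, which is exactly the silent accuracy loss this commit removes. Plain-torch sketch, no repository code involved:

import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(8, 8, bias=True)      # original layer, built with a bias
broken = nn.Linear(8, 8, bias=False)  # rebuilt the old way: bias dropped
broken.weight = lin.weight            # weights survive, the bias does not

x = torch.randn(2, 8)
print(torch.allclose(lin(x), broken(x)))  # False: outputs disagree by the bias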
