Skip to content

Commit

Permalink
fix qwen
Browse files — browse the repository at this point in the history
Signed-off-by: ssbuild <[email protected]>
  • Loading branch information
ssbuild committed Sep 26, 2023
1 parent f77fa7f commit 4802635
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/deep_training/nlp/models/qwen/modeling_qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,9 @@ class QWenModel(QWenPreTrainedModel):

def __init__(self, config: QWenConfig,**kwargs):
super().__init__(config)
self.use_cache_quantization = config.use_cache_quantization if hasattr(config,
'use_cache_quantization') else False

self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
self.embed_dim = config.hidden_size
Expand Down Expand Up @@ -1057,6 +1060,7 @@ def __init__(self, config,**kwargs):
from kernels.cpp_kernels import cache_autogptq_cuda_256
except ImportError:
cache_autogptq_cuda_256 = None

self.transformer = QWenModel(config,**kwargs)
self.lm_head = init_method(nn.Linear,config.hidden_size, config.vocab_size, bias=False,**kwargs)
if config.bf16:
Expand Down

0 comments on commit 4802635

Please sign in to comment.