From 4802635996e970dff773f73fe38d0b93f4157cb2 Mon Sep 17 00:00:00 2001
From: ssbuild <462304@qq.cn>
Date: Tue, 26 Sep 2023 17:04:05 +0800
Subject: [PATCH] fix qwen

Signed-off-by: ssbuild <462304@qq.cn>
---
 src/deep_training/nlp/models/qwen/modeling_qwen.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/deep_training/nlp/models/qwen/modeling_qwen.py b/src/deep_training/nlp/models/qwen/modeling_qwen.py
index 4cf5c4b0..2b880ce1 100644
--- a/src/deep_training/nlp/models/qwen/modeling_qwen.py
+++ b/src/deep_training/nlp/models/qwen/modeling_qwen.py
@@ -721,6 +721,9 @@ class QWenModel(QWenPreTrainedModel):
 
     def __init__(self, config: QWenConfig,**kwargs):
         super().__init__(config)
+        self.use_cache_quantization = config.use_cache_quantization if hasattr(config,
+                                                                               'use_cache_quantization') else False
+
         self.vocab_size = config.vocab_size
         self.num_hidden_layers = config.num_hidden_layers
         self.embed_dim = config.hidden_size
@@ -1057,6 +1060,7 @@ def __init__(self, config,**kwargs):
             from kernels.cpp_kernels import cache_autogptq_cuda_256
         except ImportError:
             cache_autogptq_cuda_256 = None
+
         self.transformer = QWenModel(config,**kwargs)
         self.lm_head = init_method(nn.Linear,config.hidden_size, config.vocab_size, bias=False,**kwargs)
         if config.bf16:
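
Note (editor's illustration, not part of the patch): the added line reads an
optional flag off the config and falls back to False when older configs lack
the field. A minimal sketch of that pattern is below; the QWenConfig class here
is a hypothetical stand-in for the real config object, since only its
attribute-access behavior matters. The hasattr(...) ternary used in the patch
is equivalent to getattr with a default, which is the shorter common idiom.

    # Hypothetical minimal config: stores arbitrary keyword args as attributes.
    class QWenConfig:
        def __init__(self, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)

    cfg_old = QWenConfig(hidden_size=4096)  # older config without the new flag
    cfg_new = QWenConfig(hidden_size=4096, use_cache_quantization=True)

    # Pattern used in the patch: default to False when the attribute is absent.
    use_q_old = cfg_old.use_cache_quantization if hasattr(
        cfg_old, 'use_cache_quantization') else False

    # Equivalent, more idiomatic spelling of the same defensive read.
    use_q_new = getattr(cfg_new, 'use_cache_quantization', False)

    assert use_q_old is False and use_q_new is True

Reading the flag this way keeps QWenModel backward compatible with configs
serialized before use_cache_quantization existed.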