diff --git a/README.md b/README.md
index 541a4ef1..d1cf8e1b 100644
--- a/README.md
+++ b/README.md
@@ -23,10 +23,6 @@ pip install -U git+https://github.com/ssbuild/deep_training.git --no-deps --forc
 - [poetry_training](https://github.com/ssbuild/poetry_training)
 
-## dev plan
-  - support for datasets on the way
-  - support for the transformers Trainer on the way
-  - decoupling from lightning on the way
 
 
 ## optimizer
 ```text
@@ -45,7 +41,8 @@ pip install -U git+https://github.com/ssbuild/deep_training.git --no-deps --forc
 
 ## update
 - 2023-09-21
-  - 0.2.3 support for the new qwen-7b and qwen-14b; the old versions are no longer supported, install deep_training < 0.2.3 for them
+  - 0.2.4 support for the new qwen-7b and qwen-14b; the old versions are no longer supported, install deep_training <= 0.2.3 for them
+  - support transformers trainer
 - 2023-09-21
   - 0.2.3 support full dpo training [dpo_finetuning](https://github.com/ssbuild/dpo_finetuning)
 
diff --git a/setup.py b/setup.py
index 811c705b..6432476d 100644
--- a/setup.py
+++ b/setup.py
@@ -20,7 +20,7 @@
 ]
 setup(
     name='deep_training',
-    version='0.2.4rc0',
+    version='0.2.4',
     description='an easy training architecture',
     long_description='torch_training: https://github.com/ssbuild/deep_training.git',
     license='Apache License 2.0',
diff --git a/src/deep_training/data_helper/data_helper.py b/src/deep_training/data_helper/data_helper.py
index f1893042..25af9601 100644
--- a/src/deep_training/data_helper/data_helper.py
+++ b/src/deep_training/data_helper/data_helper.py
@@ -6,6 +6,7 @@
 # from fastdatasets.torch_dataset import IterableDataset as torch_IterableDataset, Dataset as torch_Dataset
 # from torch.utils.data import DataLoader, IterableDataset
 import os
+import typing
 from typing import Optional, Union
 from transformers import PreTrainedTokenizer, PretrainedConfig
 from .training_args import ModelArguments, DataArguments, TrainingArguments,TrainingArgumentsHF
@@ -95,8 +96,7 @@ def load_config(self,
                     with_labels=True,
                     with_task_params=True,
                     return_dict=False,
-                    with_print_labels=True,
-                    with_print_config=True,
+                    with_print_labels=None,
                     **kwargs):
 
         model_args = self.model_args
@@ -143,8 +143,6 @@ def load_config(self,
             **kwargs_args
         )
         self.config = config
-        if with_print_config:
-            print(config)
 
         if with_labels and self.label2id is not None and hasattr(config, 'num_labels'):
             if with_print_labels:
@@ -164,7 +162,6 @@ def load_tokenizer_and_config(self,
                                   with_task_params=True,
                                   return_dict=False,
                                   with_print_labels=True,
-                                  with_print_config=True,
                                   tokenizer_kwargs=None,
                                   config_kwargs=None):
 
@@ -175,7 +172,7 @@ def load_tokenizer_and_config(self,
             config_kwargs = {}
 
         model_args: ModelArguments = self.model_args
-        training_args: TrainingArguments = self.training_args
+        training_args: typing.Optional[typing.Union[TrainingArguments, TrainingArgumentsHF]] = self.training_args
         data_args: DataArguments = self.data_args
 
 
@@ -234,9 +231,6 @@ def load_tokenizer_and_config(self,
             **kwargs_args
         )
         self.config = config
-        if with_print_config:
-            print(config)
-
         if with_labels and self.label2id is not None and hasattr(config, 'num_labels'):
             if with_print_labels:
                 print('==' * 30, 'num_labels = ', config.num_labels)
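A side note on the `data_helper.py` change (not part of the patch): `load_config()` and `load_tokenizer_and_config()` no longer print the loaded config, since the `with_print_config` flag was dropped. Callers that relied on that output can print the config themselves; a minimal sketch, assuming `dataHelper` is an already-constructed `DataHelper` subclass instance as used elsewhere in this repo:

```python
# Hedged illustration only; `dataHelper` stands for any already-built DataHelper subclass.
dataHelper.load_tokenizer_and_config()
# with_print_config is gone, so inspect the config explicitly if you still want the old output.
print(dataHelper.config)  # self.config is set inside load_tokenizer_and_config(), see the hunk above
```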
diff --git a/src/deep_training/nlp/models/qwen/modeling_qwen.py b/src/deep_training/nlp/models/qwen/modeling_qwen.py
index 1713c98d..4cf5c4b0 100644
--- a/src/deep_training/nlp/models/qwen/modeling_qwen.py
+++ b/src/deep_training/nlp/models/qwen/modeling_qwen.py
@@ -29,7 +29,10 @@
 from transformers.utils import logging
 from einops import rearrange
 from torch import nn
-
+try:
+    from kernels.cpp_kernels import cache_autogptq_cuda_256
+except ImportError:
+    cache_autogptq_cuda_256 = None
 from .configuration_qwen import QWenConfig
 from .qwen_generation_utils import (
     HistoryType,
@@ -117,8 +120,34 @@ def setup_model_profile(skip_init_flag=True):
         skip_init_function = skip_init
     else:
         skip_init_function = default_init
-
-
+
+
+
+def quantize_cache_v(fdata, bits, qmax, qmin):
+    # b, s, head, h-dim->b, head, s, h-dim
+    qtype = torch.uint8
+    device = fdata.device
+    shape = fdata.shape
+
+    fdata_cal = torch.flatten(fdata, 2)
+    fmax = torch.amax(fdata_cal, dim=-1, keepdim=True)
+    fmin = torch.amin(fdata_cal, dim=-1, keepdim=True)
+    # Compute params
+    if qmax.device != fmax.device:
+        qmax = qmax.to(device)
+        qmin = qmin.to(device)
+    scale = (fmax - fmin) / (qmax - qmin)
+    zero = qmin - fmin / scale
+    scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
+    zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
+    # Quantize
+    res_data = fdata / scale + zero
+    qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
+    return qdata.contiguous(), scale, zero
+
+def dequantize_cache_torch(qdata, scale, zero):
+    data = scale * (qdata - zero)
+    return data
 
 class FlashSelfAttention(torch.nn.Module):
     def __init__(
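For orientation (the snippet below is editorial, not part of the patch): the pair of helpers added above implements a per-(batch, head) affine uint8 quantization of the KV cache. The standalone sketch mirrors their math on a dummy tensor in the `(batch, head, seq, head_dim)` layout the model uses after `permute(0, 2, 1, 3)` and checks the round trip:

```python
import torch

# Standalone illustration of the uint8 KV-cache quantization above.
qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=torch.float32)  # 255.0
qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=torch.float32)  # 0.0

fdata = torch.randn(1, 2, 4, 8)                      # (batch, head, seq, head_dim)
fdata_cal = torch.flatten(fdata, 2)                  # flatten seq and head_dim for per-head min/max
fmax = torch.amax(fdata_cal, dim=-1, keepdim=True)
fmin = torch.amin(fdata_cal, dim=-1, keepdim=True)
scale = (fmax - fmin) / (qmax - qmin)                # affine scale per (batch, head)
zero = qmin - fmin / scale                           # zero point so that fmin maps to qmin
scale = scale.unsqueeze(-1).repeat(1, 1, fdata.shape[2], 1)
zero = zero.unsqueeze(-1).repeat(1, 1, fdata.shape[2], 1)
qdata = torch.clamp(fdata / scale + zero, qmin, qmax).to(torch.uint8)

recon = scale * (qdata - zero)                       # same math as dequantize_cache_torch
print((recon - fdata).abs().max())                   # error is at most about one quantization step
```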
@@ -138,7 +167,6 @@ def __init__(
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout
 
-
     def unpad_input(self, hidden_states, attention_mask):
         valid_mask = attention_mask.squeeze(1).squeeze(1).eq(0)
         seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
@@ -176,7 +204,6 @@ def forward(self, q, k, v, attention_mask=None):
                 q = q[indices_k]
                 cu_seqlens_q = cu_seqlens_k
             else:
-
                 cu_seqlens_k = torch.arange(
                     0,
                     (batch_size + 1) * seqlen_k,
@@ -264,21 +291,52 @@ def __init__(self, config,**kwargs):
             ]
             logn_tensor = torch.tensor(logn_list)[None, :, None, None]
             self.register_buffer("logn_tensor", logn_tensor, persistent=False)
-
         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
+        self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False
+        self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False
+        cache_dtype = torch.float
+        if self.bf16:
+            cache_dtype = torch.bfloat16
+        elif config.fp16:
+            cache_dtype = torch.float16
+        self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
+        self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
 
     def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
+        device = query.device
+        if self.use_cache_quantization:
+            qk, qk_scale, qk_zero = key
+            if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
+                shape = query.shape[:-1] + (qk.shape[-2],)
+                attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
+                cache_autogptq_cuda_256.vecquant8matmul_batched_faster_old(
+                    query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
+                    qk.transpose(-1, -2).contiguous(),
+                    attn_weights,
+                    qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(),
+                    qk_zero.contiguous() if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous())
+                # attn_weights = attn_weights.to(query.dtype).contiguous()
+            else:
+                key = dequantize_cache_torch(qk, qk_scale, qk_zero)
+                attn_weights = torch.matmul(query, key.transpose(-1, -2))
+        else:
+            attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         if self.scale_attn_weights:
+            if self.use_cache_quantization:
+                size_temp = value[0].size(-1)
+            else:
+                size_temp = value.size(-1)
             attn_weights = attn_weights / torch.full(
                 [],
-                value.size(-1) ** 0.5,
+                size_temp ** 0.5,
                 dtype=attn_weights.dtype,
                 device=attn_weights.device,
             )
-
-        query_length, key_length = query.size(-2), key.size(-2)
+        if self.use_cache_quantization:
+            query_length, key_length = query.size(-2), key[0].size(-2)
+        else:
+            query_length, key_length = query.size(-2), key.size(-2)
         causal_mask = registered_causal_mask[
             :, :, key_length - query_length : key_length, :key_length
         ]
@@ -289,18 +347,38 @@ def _attn(self, query, key, value, registered_causal_mask, attention_mask=None,
         attn_weights = torch.where(
             causal_mask, attn_weights.to(attn_weights.dtype), mask_value
         )
+
         if attention_mask is not None:
             attn_weights = attn_weights + attention_mask
 
         attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
-        attn_weights = attn_weights.type(value.dtype)
+        attn_weights = attn_weights.type(query.dtype)
         attn_weights = self.attn_dropout(attn_weights)
 
         if head_mask is not None:
             attn_weights = attn_weights * head_mask
 
-        attn_output = torch.matmul(attn_weights, value)
+        if self.use_cache_quantization:
+            qv, qv_scale, qv_zero = value
+            if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
+                shape = attn_weights.shape[:-1] + (query.shape[-1],)
+                attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
+                cache_autogptq_cuda_256.vecquant8matmul_batched_column_compression_faster_old(
+                    attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
+                    qv.contiguous(),  # dtype: int32
+                    attn_output,
+                    qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(),
+                    qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous())
+                if attn_output.dtype != query.dtype:
+                    attn_output = attn_output.to(query.dtype)
+                    attn_weights = attn_weights.to(query.dtype)
+            else:
+                value = dequantize_cache_torch(qv, qv_scale, qv_zero)
+                attn_output = torch.matmul(attn_weights, value)
+        else:
+            attn_output = torch.matmul(attn_weights, value)
+
         attn_output = attn_output.transpose(1, 2)
 
         return attn_output, attn_weights
 
@@ -415,10 +493,34 @@ def forward(
             query = torch.cat(query_list, dim=0)
             key = torch.cat(key_list, dim=0)
 
+        if self.use_cache_quantization:
+            key = quantize_cache_v(key.permute(0, 2, 1, 3),
+                                   bits=8,
+                                   qmin=self.cache_qmin,
+                                   qmax=self.cache_qmax)
+            value = quantize_cache_v(value.permute(0, 2, 1, 3),
+                                     bits=8,
+                                     qmin=self.cache_qmin,
+                                     qmax=self.cache_qmax)
+
+
         if layer_past is not None:
             past_key, past_value = layer_past[0], layer_past[1]
-            key = torch.cat((past_key, key), dim=1)
-            value = torch.cat((past_value, value), dim=1)
+            if self.use_cache_quantization:
+                # use_cache_quantization:
+                #     present=((q_key,key_scale,key_zero_point),
+                #              (q_value,value_scale,value_zero_point))
+                key = (torch.cat((past_key[0], key[0]), dim=2),
+                       torch.cat((past_key[1], key[1]), dim=2),
+                       torch.cat((past_key[2], key[2]), dim=2))
+                value = (torch.cat((past_value[0], value[0]), dim=2),
+                         torch.cat((past_value[1], value[1]), dim=2),
+                         torch.cat((past_value[2], value[2]), dim=2))
+            else:
+                # not use_cache_quantization:
+                #     present=(key,value)
+                key = torch.cat((past_key, key), dim=1)
+                value = torch.cat((past_value, value), dim=1)
 
         if use_cache:
             present = (key, value)
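To make the cache layout above concrete (illustrative sketch, not part of the patch): with `use_cache_quantization` each cached key or value is a `(q_data, scale, zero_point)` triple in `(batch, head, seq, head_dim)` layout, so past and current entries are concatenated along `dim=2` rather than the unquantized `dim=1`:

```python
import torch

# Five cached positions plus one new position for a toy 2-head, 8-dim cache.
past_key = (torch.zeros(1, 2, 5, 8, dtype=torch.uint8),   # q_key
            torch.ones(1, 2, 5, 1),                        # per-position scale
            torch.zeros(1, 2, 5, 1))                       # per-position zero point
new_key = (torch.zeros(1, 2, 1, 8, dtype=torch.uint8),
           torch.ones(1, 2, 1, 1),
           torch.zeros(1, 2, 1, 1))
key = tuple(torch.cat((p, n), dim=2) for p, n in zip(past_key, new_key))
print(key[0].shape)  # torch.Size([1, 2, 6, 8]) -> sequence length grew from 5 to 6
```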
@@ -426,10 +528,12 @@ def forward(
             present = None
 
         if self.use_logn_attn and not self.training:
-
-
-            seq_start = key.size(1) - query.size(1)
-            seq_end = key.size(1)
+            if self.use_cache_quantization:
+                seq_start = key[0].size(2) - query.size(1)
+                seq_end = key[0].size(2)
+            else:
+                seq_start = key.size(1) - query.size(1)
+                seq_end = key.size(1)
             logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
             query = query * logn_tensor.expand_as(query)
 
@@ -447,8 +551,9 @@ def forward(
         else:
             query = query.permute(0, 2, 1, 3)
-            key = key.permute(0, 2, 1, 3)
-            value = value.permute(0, 2, 1, 3)
+            if not self.use_cache_quantization:
+                key = key.permute(0, 2, 1, 3)
+                value = value.permute(0, 2, 1, 3)
 
             if (
                 registered_causal_mask is None
                 and self.use_flash_attn
@@ -523,17 +628,17 @@ def __init__(self, config, **kwargs):
         self.mlp = init_method(QWenMLP,config,**kwargs)
 
     def forward(
-            self,
-            hidden_states: Optional[Tuple[torch.FloatTensor]],
-            rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
-            registered_causal_mask: Optional[torch.Tensor] = None,
-            layer_past: Optional[Tuple[torch.Tensor]] = None,
-            attention_mask: Optional[torch.FloatTensor] = None,
-            head_mask: Optional[torch.FloatTensor] = None,
-            encoder_hidden_states: Optional[torch.Tensor] = None,
-            encoder_attention_mask: Optional[torch.FloatTensor] = None,
-            use_cache: Optional[bool] = False,
-            output_attentions: Optional[bool] = False,
+        self,
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
+        registered_causal_mask: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
     ):
         layernorm_output = self.ln_1(hidden_states)
 
@@ -748,7 +853,10 @@ def forward(
             past_length = 0
             past_key_values = tuple([None] * len(self.h))
         else:
-            past_length = past_key_values[0][0].size(-2)
+            if self.use_cache_quantization:
+                past_length = past_key_values[0][0][0].size(2)
+            else:
+                past_length = past_key_values[0][0].size(-2)
 
         if position_ids is None:
             position_ids = torch.arange(
@@ -777,15 +885,15 @@ def forward(
         kv_seq_len = hidden_states.size()[1]
         if past_key_values[0] is not None:
             # past key values[0][0] shape: bs * seq_len * head_num * dim
-            kv_seq_len += past_key_values[0][0].shape[1]
+            if self.use_cache_quantization:
+                kv_seq_len += past_key_values[0][0][0].shape[2]
+            else:
+                kv_seq_len += past_key_values[0][0].shape[1]
 
         if self.training or not self.use_dynamic_ntk:
             ntk_alpha_list = [1.0]
         elif kv_seq_len != hidden_states.size()[1]:
             ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
-
-
-
         else:
             ntk_alpha_list = []
 
@@ -942,6 +1050,13 @@ def __init__(self, config,**kwargs):
         if config.use_flash_attn:
             _import_flash_attn()
 
+        if hasattr(config, 'use_cache_quantization') and config.use_cache_quantization:
+            config.use_flash_attn = False
+        if hasattr(config, 'use_cache_kernel') and config.use_cache_kernel:
+            try:
+                from kernels.cpp_kernels import cache_autogptq_cuda_256
+            except ImportError:
+                cache_autogptq_cuda_256 = None
         self.transformer = QWenModel(config,**kwargs)
         self.lm_head = init_method(nn.Linear,config.hidden_size, config.vocab_size, bias=False,**kwargs)
         if config.bf16:
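The `QWenLMHeadModel.__init__` hunk above reads the two new flags from the model config and forces flash attention off when the quantized cache is requested. A hedged sketch of switching the feature on; the checkpoint path is a placeholder, the import path follows the file path in this patch, and the exact model-building call depends on the surrounding project:

```python
# Hedged sketch; adapt the path and the model construction to your setup.
from deep_training.nlp.models.qwen.configuration_qwen import QWenConfig

config = QWenConfig.from_pretrained("/path/to/qwen-checkpoint")
config.use_cache_quantization = True   # cache K/V as uint8 (q_data, scale, zero_point) triples
config.use_cache_kernel = False        # optional fused kernel; requires kernels.cpp_kernels to import
# QWenLMHeadModel.__init__ above will then force config.use_flash_attn = False.
```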
@@ -1142,23 +1257,20 @@ def chat_stream(
             history: Optional[HistoryType],
             system: str = "You are a helpful assistant.",
             generation_config: Optional[GenerationConfig] = None,
+            stop_words_ids=None,
             **kwargs
     ) -> Generator[str, Any, None]:
         generation_config = generation_config if generation_config is not None else self.generation_config
         assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
-
-        stop_words_ids = kwargs.pop('stop_words_ids',[])
-        if not isinstance(stop_words_ids,list):
-            stop_words_ids = [stop_words_ids]
-
-
-
+        if stop_words_ids is None:
+            stop_words_ids = []
         max_window_size = kwargs.get('max_window_size', None)
         if max_window_size is None:
             max_window_size = generation_config.max_window_size
         logits_processor = kwargs.pop('logits_processor',None)
+
         raw_text, context_tokens = make_context(
             tokenizer,
             query,
@@ -1250,7 +1362,7 @@ def __init__(self, dim, base=10000,**kwargs):
         self.dim = dim
         self.base = base
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2,**kwargs).float() / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
 
         if importlib.util.find_spec("einops") is None:
             raise RuntimeError("einops is required for Rotary Embedding")
@@ -1301,7 +1413,6 @@ def apply_rotary_pos_emb(t, freqs):
     if apply_rotary_emb_func is not None and t.is_cuda:
         t_ = t.float()
         cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]
-
         sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]
         output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
         return output
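Lastly, `chat_stream()` now accepts `stop_words_ids` as an explicit keyword argument (defaulting to an empty list) instead of popping it from `**kwargs`. A hedged usage sketch; `model`, `tokenizer`, and the stop-word token ids are placeholders:

```python
# Illustrative call only; the objects and token ids below are placeholders.
for chunk in model.chat_stream(
        tokenizer,
        "Hello, who are you?",
        history=[],
        stop_words_ids=[[151643]],   # list of token-id sequences to stop on
):
    print(chunk)
```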