diff --git a/README.md b/README.md
index 541a4ef1..d1cf8e1b 100644
--- a/README.md
+++ b/README.md
@@ -23,10 +23,6 @@ pip install -U git+https://github.com/ssbuild/deep_training.git --no-deps --forc
- [poetry_training](https://github.com/ssbuild/poetry_training)
-## dev plan
- - 支持 datasets on the way
- - 支持 transformer Trainer on the way
- - 解耦 lightning on the way
## optimizer
```text
@@ -45,7 +41,8 @@ pip install -U git+https://github.com/ssbuild/deep_training.git --no-deps --forc
## update
- 2023-09-21
- - 0.2.3 支持qwen-7b 新版 和 qwen-14b , 旧版不再支持,旧版可以安装 deep_training < 0.2.3
+  - 0.2.4 supports the new qwen-7b and qwen-14b; the old versions are no longer supported and can use deep_training <= 0.2.3
+    - support the transformers Trainer
- 2023-09-21
   - 0.2.3 support full dpo training [dpo_finetuning](https://github.com/ssbuild/dpo_finetuning)
diff --git a/setup.py b/setup.py
index 811c705b..6432476d 100644
--- a/setup.py
+++ b/setup.py
@@ -20,7 +20,7 @@
]
setup(
name='deep_training',
- version='0.2.4rc0',
+ version='0.2.4',
description='an easy training architecture',
long_description='torch_training: https://github.com/ssbuild/deep_training.git',
license='Apache License 2.0',
diff --git a/src/deep_training/data_helper/data_helper.py b/src/deep_training/data_helper/data_helper.py
index f1893042..25af9601 100644
--- a/src/deep_training/data_helper/data_helper.py
+++ b/src/deep_training/data_helper/data_helper.py
@@ -6,6 +6,7 @@
# from fastdatasets.torch_dataset import IterableDataset as torch_IterableDataset, Dataset as torch_Dataset
# from torch.utils.data import DataLoader, IterableDataset
import os
+import typing
from typing import Optional, Union
from transformers import PreTrainedTokenizer, PretrainedConfig
from .training_args import ModelArguments, DataArguments, TrainingArguments,TrainingArgumentsHF
@@ -95,8 +96,7 @@ def load_config(self,
with_labels=True,
with_task_params=True,
return_dict=False,
- with_print_labels=True,
- with_print_config=True,
+ with_print_labels=None,
**kwargs):
model_args = self.model_args
@@ -143,8 +143,6 @@ def load_config(self,
**kwargs_args
)
self.config = config
- if with_print_config:
- print(config)
if with_labels and self.label2id is not None and hasattr(config, 'num_labels'):
if with_print_labels:
@@ -164,7 +162,6 @@ def load_tokenizer_and_config(self,
with_task_params=True,
return_dict=False,
with_print_labels=True,
- with_print_config=True,
tokenizer_kwargs=None,
config_kwargs=None):
@@ -175,7 +172,7 @@ def load_tokenizer_and_config(self,
config_kwargs = {}
model_args: ModelArguments = self.model_args
- training_args: TrainingArguments = self.training_args
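+    # training_args may be either TrainingArguments or TrainingArgumentsHF
+    # (imported above), depending on the trainer backend in use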
+    training_args: typing.Union[TrainingArguments, TrainingArgumentsHF] = self.training_args
data_args: DataArguments = self.data_args
@@ -234,9 +231,6 @@ def load_tokenizer_and_config(self,
**kwargs_args
)
self.config = config
- if with_print_config:
- print(config)
-
if with_labels and self.label2id is not None and hasattr(config, 'num_labels'):
if with_print_labels:
print('==' * 30, 'num_labels = ', config.num_labels)
diff --git a/src/deep_training/nlp/models/qwen/modeling_qwen.py b/src/deep_training/nlp/models/qwen/modeling_qwen.py
index 1713c98d..4cf5c4b0 100644
--- a/src/deep_training/nlp/models/qwen/modeling_qwen.py
+++ b/src/deep_training/nlp/models/qwen/modeling_qwen.py
@@ -29,7 +29,10 @@
from transformers.utils import logging
from einops import rearrange
from torch import nn
-
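+# Optional fused 8-bit CUDA kernels for the quantized KV cache; when the extension is
+# absent, the attention path falls back to dequantize_cache_torch below.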
+try:
+ from kernels.cpp_kernels import cache_autogptq_cuda_256
+except ImportError:
+ cache_autogptq_cuda_256 = None
from .configuration_qwen import QWenConfig
from .qwen_generation_utils import (
HistoryType,
@@ -117,8 +120,34 @@ def setup_model_profile(skip_init_flag=True):
skip_init_function = skip_init
else:
skip_init_function = default_init
-
-
+
+
+
+def quantize_cache_v(fdata, bits, qmax, qmin):
+ # b, s, head, h-dim->b, head, s, h-dim
+ qtype = torch.uint8
+ device = fdata.device
+ shape = fdata.shape
+
+ fdata_cal = torch.flatten(fdata, 2)
+ fmax = torch.amax(fdata_cal, dim=-1, keepdim=True)
+ fmin = torch.amin(fdata_cal, dim=-1, keepdim=True)
+ # Compute params
+ if qmax.device != fmax.device:
+ qmax = qmax.to(device)
+ qmin = qmin.to(device)
+ scale = (fmax - fmin) / (qmax - qmin)
+ zero = qmin - fmin / scale
+ scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
+ zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
+ # Quantize
+ res_data = fdata / scale + zero
+ qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
+ return qdata.contiguous(), scale, zero
+
+def dequantize_cache_torch(qdata, scale, zero):
+ data = scale * (qdata - zero)
+ return data
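+
+# A minimal round-trip sketch (illustrative only; in forward() below the key/value are
+# first permuted from [batch, seq, heads, head_dim] to [batch, heads, seq, head_dim]):
+#   qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=torch.float)
+#   qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=torch.float)
+#   q_key, scale, zero = quantize_cache_v(key.permute(0, 2, 1, 3), bits=8, qmin=qmin, qmax=qmax)
+#   key_approx = dequantize_cache_torch(q_key, scale, zero)  # ~key, up to uint8 rounding error
+
+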
class FlashSelfAttention(torch.nn.Module):
def __init__(
@@ -138,7 +167,6 @@ def __init__(
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
-
def unpad_input(self, hidden_states, attention_mask):
valid_mask = attention_mask.squeeze(1).squeeze(1).eq(0)
seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
@@ -176,7 +204,6 @@ def forward(self, q, k, v, attention_mask=None):
q = q[indices_k]
cu_seqlens_q = cu_seqlens_k
else:
-
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
@@ -264,21 +291,52 @@ def __init__(self, config,**kwargs):
]
logn_tensor = torch.tensor(logn_list)[None, :, None, None]
self.register_buffer("logn_tensor", logn_tensor, persistent=False)
-
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
+        self.use_cache_quantization = getattr(config, 'use_cache_quantization', False)
+        self.use_cache_kernel = getattr(config, 'use_cache_kernel', False)
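+        # A hedged sketch of enabling the quantized KV cache via the model config
+        # (the two attributes read above):
+        #   config.use_cache_quantization = True  # cache K/V as uint8 (data, scale, zero) tuples
+        #   config.use_cache_kernel = True        # use cache_autogptq_cuda_256 when importable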
+ cache_dtype = torch.float
+ if self.bf16:
+            cache_dtype = torch.bfloat16
+ elif config.fp16:
+ cache_dtype = torch.float16
+ self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
+ self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
- attn_weights = torch.matmul(query, key.transpose(-1, -2))
+ device = query.device
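+        # When use_cache_quantization is enabled, `key` and `value` arrive here as
+        # (uint8 data, scale, zero) tuples produced by quantize_cache_v in forward();
+        # otherwise they are plain tensors.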
+ if self.use_cache_quantization:
+ qk, qk_scale, qk_zero = key
+ if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
+ shape = query.shape[:-1] + (qk.shape[-2],)
+ attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
+ cache_autogptq_cuda_256.vecquant8matmul_batched_faster_old(
+ query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
+ qk.transpose(-1, -2).contiguous(),
+ attn_weights,
+ qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(),
+                    qk_zero.contiguous() if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous())
+ # attn_weights = attn_weights.to(query.dtype).contiguous()
+ else:
+ key = dequantize_cache_torch(qk, qk_scale, qk_zero)
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
+ else:
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
if self.scale_attn_weights:
+ if self.use_cache_quantization:
+ size_temp = value[0].size(-1)
+ else:
+ size_temp = value.size(-1)
attn_weights = attn_weights / torch.full(
[],
- value.size(-1) ** 0.5,
+ size_temp ** 0.5,
dtype=attn_weights.dtype,
device=attn_weights.device,
)
-
- query_length, key_length = query.size(-2), key.size(-2)
+ if self.use_cache_quantization:
+ query_length, key_length = query.size(-2), key[0].size(-2)
+ else:
+ query_length, key_length = query.size(-2), key.size(-2)
causal_mask = registered_causal_mask[
:, :, key_length - query_length : key_length, :key_length
]
@@ -289,18 +347,38 @@ def _attn(self, query, key, value, registered_causal_mask, attention_mask=None,
attn_weights = torch.where(
causal_mask, attn_weights.to(attn_weights.dtype), mask_value
)
+
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
- attn_weights = attn_weights.type(value.dtype)
+ attn_weights = attn_weights.type(query.dtype)
attn_weights = self.attn_dropout(attn_weights)
if head_mask is not None:
attn_weights = attn_weights * head_mask
- attn_output = torch.matmul(attn_weights, value)
+ if self.use_cache_quantization:
+ qv, qv_scale, qv_zero = value
+ if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
+ shape = attn_weights.shape[:-1] + (query.shape[-1],)
+ attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
+ cache_autogptq_cuda_256.vecquant8matmul_batched_column_compression_faster_old(
+ attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
+ qv.contiguous(), # dtype: int32
+ attn_output,
+ qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(),
+ qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous())
+ if attn_output.dtype != query.dtype:
+ attn_output = attn_output.to(query.dtype)
+ attn_weights = attn_weights.to(query.dtype)
+ else:
+ value = dequantize_cache_torch(qv, qv_scale, qv_zero)
+ attn_output = torch.matmul(attn_weights, value)
+ else:
+ attn_output = torch.matmul(attn_weights, value)
+
attn_output = attn_output.transpose(1, 2)
return attn_output, attn_weights
@@ -415,10 +493,34 @@ def forward(
query = torch.cat(query_list, dim=0)
key = torch.cat(key_list, dim=0)
+ if self.use_cache_quantization:
+ key = quantize_cache_v(key.permute(0, 2, 1, 3),
+ bits=8,
+ qmin=self.cache_qmin,
+ qmax=self.cache_qmax)
+ value = quantize_cache_v(value.permute(0, 2, 1, 3),
+ bits=8,
+ qmin=self.cache_qmin,
+ qmax=self.cache_qmax)
+
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
- key = torch.cat((past_key, key), dim=1)
- value = torch.cat((past_value, value), dim=1)
+ if self.use_cache_quantization:
+ # use_cache_quantization:
+ # present=((q_key,key_scale,key_zero_point),
+ # (q_value,value_scale,value_zero_point))
+ key = (torch.cat((past_key[0], key[0]), dim=2),
+ torch.cat((past_key[1], key[1]), dim=2),
+ torch.cat((past_key[2], key[2]), dim=2))
+ value = (torch.cat((past_value[0], value[0]), dim=2),
+ torch.cat((past_value[1], value[1]), dim=2),
+ torch.cat((past_value[2], value[2]), dim=2))
+ else:
+ # not use_cache_quantization:
+ # present=(key,value)
+ key = torch.cat((past_key, key), dim=1)
+ value = torch.cat((past_value, value), dim=1)
if use_cache:
present = (key, value)
@@ -426,10 +528,12 @@ def forward(
present = None
if self.use_logn_attn and not self.training:
-
-
- seq_start = key.size(1) - query.size(1)
- seq_end = key.size(1)
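+            # For the quantized cache, key is (uint8 [batch, heads, seq, dim], scale, zero),
+            # so the sequence length is read from dim 2 of key[0] rather than dim 1.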
+ if self.use_cache_quantization:
+ seq_start = key[0].size(2) - query.size(1)
+ seq_end = key[0].size(2)
+ else:
+ seq_start = key.size(1) - query.size(1)
+ seq_end = key.size(1)
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
query = query * logn_tensor.expand_as(query)
@@ -447,8 +551,9 @@ def forward(
else:
query = query.permute(0, 2, 1, 3)
- key = key.permute(0, 2, 1, 3)
- value = value.permute(0, 2, 1, 3)
+ if not self.use_cache_quantization:
+ key = key.permute(0, 2, 1, 3)
+ value = value.permute(0, 2, 1, 3)
if (
registered_causal_mask is None
and self.use_flash_attn
@@ -523,17 +628,17 @@ def __init__(self, config, **kwargs):
self.mlp = init_method(QWenMLP,config,**kwargs)
def forward(
- self,
- hidden_states: Optional[Tuple[torch.FloatTensor]],
- rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
- registered_causal_mask: Optional[torch.Tensor] = None,
- layer_past: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = False,
- output_attentions: Optional[bool] = False,
+ self,
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
+ registered_causal_mask: Optional[torch.Tensor] = None,
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
):
layernorm_output = self.ln_1(hidden_states)
@@ -748,7 +853,10 @@ def forward(
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
- past_length = past_key_values[0][0].size(-2)
+ if self.use_cache_quantization:
+ past_length = past_key_values[0][0][0].size(2)
+ else:
+ past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
@@ -777,15 +885,15 @@ def forward(
kv_seq_len = hidden_states.size()[1]
if past_key_values[0] is not None:
# past key values[0][0] shape: bs * seq_len * head_num * dim
- kv_seq_len += past_key_values[0][0].shape[1]
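+                # For the quantized cache, past_key_values[0][0] is a (data, scale, zero)
+                # tuple whose uint8 data is [bs, head_num, seq_len, dim], hence shape[2].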
+ if self.use_cache_quantization:
+ kv_seq_len += past_key_values[0][0][0].shape[2]
+ else:
+ kv_seq_len += past_key_values[0][0].shape[1]
if self.training or not self.use_dynamic_ntk:
ntk_alpha_list = [1.0]
elif kv_seq_len != hidden_states.size()[1]:
ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
-
-
-
else:
ntk_alpha_list = []
if attention_mask is not None and kv_seq_len > self.seq_length:
@@ -942,6 +1050,13 @@ def __init__(self, config,**kwargs):
if config.use_flash_attn:
_import_flash_attn()
+ if hasattr(config, 'use_cache_quantization') and config.use_cache_quantization:
+ config.use_flash_attn = False
+ if hasattr(config, 'use_cache_kernel') and config.use_cache_kernel:
+ try:
+ from kernels.cpp_kernels import cache_autogptq_cuda_256
+ except ImportError:
+ cache_autogptq_cuda_256 = None
self.transformer = QWenModel(config,**kwargs)
self.lm_head = init_method(nn.Linear,config.hidden_size, config.vocab_size, bias=False,**kwargs)
if config.bf16:
@@ -1142,23 +1257,20 @@ def chat_stream(
history: Optional[HistoryType],
system: str = "You are a helpful assistant.",
generation_config: Optional[GenerationConfig] = None,
+        stop_words_ids: Optional[List[List[int]]] = None,
**kwargs
) -> Generator[str, Any, None]:
generation_config = generation_config if generation_config is not None else self.generation_config
assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
if history is None:
history = []
-
- stop_words_ids = kwargs.pop('stop_words_ids',[])
- if not isinstance(stop_words_ids,list):
- stop_words_ids = [stop_words_ids]
-
-
-
+ if stop_words_ids is None:
+ stop_words_ids = []
max_window_size = kwargs.get('max_window_size', None)
if max_window_size is None:
max_window_size = generation_config.max_window_size
logits_processor = kwargs.pop('logits_processor',None)
+
raw_text, context_tokens = make_context(
tokenizer,
query,
@@ -1250,7 +1362,7 @@ def __init__(self, dim, base=10000,**kwargs):
self.dim = dim
self.base = base
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2,**kwargs).float() / dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
if importlib.util.find_spec("einops") is None:
raise RuntimeError("einops is required for Rotary Embedding")
@@ -1301,7 +1413,6 @@ def apply_rotary_pos_emb(t, freqs):
if apply_rotary_emb_func is not None and t.is_cuda:
t_ = t.float()
cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]
-
sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]
output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
return output