Skip to content

Commit

Permalink
fix qwen
Browse files · Browse the repository at this point in the history
Signed-off-by: ssbuild <[email protected]>
  • Loading branch information
ssbuild committed Sep 26, 2023
1 parent 3eb65d6 commit f77fa7f
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 60 deletions.
7 changes: 2 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ pip install -U git+https://github.com/ssbuild/deep_training.git --no-deps --forc
- [poetry_training](https://github.com/ssbuild/poetry_training)


## dev plan
- 支持 datasets on the way
- 支持 transformer Trainer on the way
- 解耦 lightning on the way

## optimizer
```text
Expand All @@ -45,7 +41,8 @@ pip install -U git+https://github.com/ssbuild/deep_training.git --no-deps --forc

## update
- <strong>2023-09-21</strong>
- 0.2.3 支持qwen-7b 新版 和 qwen-14b , 旧版不再支持,旧版可以安装 deep_training < 0.2.3
- 0.2.4 支持qwen-7b 新版 和 qwen-14b , 旧版不再支持,旧版可以安装 deep_training <= 0.2.3
- support transformers trainer

- <strong>2023-09-21</strong>
- 0.2.3 support dpo 完整训练 [dpo_finetuning](https://github.com/ssbuild/dpo_finetuning)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
]
setup(
name='deep_training',
version='0.2.4rc0',
version='0.2.4',
description='an easy training architecture',
long_description='torch_training: https://github.com/ssbuild/deep_training.git',
license='Apache License 2.0',
Expand Down
12 changes: 3 additions & 9 deletions src/deep_training/data_helper/data_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# from fastdatasets.torch_dataset import IterableDataset as torch_IterableDataset, Dataset as torch_Dataset
# from torch.utils.data import DataLoader, IterableDataset
import os
import typing
from typing import Optional, Union
from transformers import PreTrainedTokenizer, PretrainedConfig
from .training_args import ModelArguments, DataArguments, TrainingArguments,TrainingArgumentsHF
Expand Down Expand Up @@ -95,8 +96,7 @@ def load_config(self,
with_labels=True,
with_task_params=True,
return_dict=False,
with_print_labels=True,
with_print_config=True,
with_print_labels=None,
**kwargs):

model_args = self.model_args
Expand Down Expand Up @@ -143,8 +143,6 @@ def load_config(self,
**kwargs_args
)
self.config = config
if with_print_config:
print(config)

if with_labels and self.label2id is not None and hasattr(config, 'num_labels'):
if with_print_labels:
Expand All @@ -164,7 +162,6 @@ def load_tokenizer_and_config(self,
with_task_params=True,
return_dict=False,
with_print_labels=True,
with_print_config=True,
tokenizer_kwargs=None,
config_kwargs=None):

Expand All @@ -175,7 +172,7 @@ def load_tokenizer_and_config(self,
config_kwargs = {}

model_args: ModelArguments = self.model_args
training_args: TrainingArguments = self.training_args
training_args: typing.Optional[TrainingArguments,TrainingArgumentsHF] = self.training_args
data_args: DataArguments = self.data_args


Expand Down Expand Up @@ -234,9 +231,6 @@ def load_tokenizer_and_config(self,
**kwargs_args
)
self.config = config
if with_print_config:
print(config)

if with_labels and self.label2id is not None and hasattr(config, 'num_labels'):
if with_print_labels:
print('==' * 30, 'num_labels = ', config.num_labels)
Expand Down
Loading

0 comments on commit f77fa7f

Please sign in to comment.