Commit: Merge branch 'main' into baichuan_devices

ShawnXuan authored Sep 20, 2024
2 parents df09c7d + 169be08, commit f95734e
Showing 12 changed files with 1,679 additions and 2 deletions.
6 changes: 5 additions & 1 deletion libai/tokenizer/tokenization_base.py
@@ -827,7 +827,11 @@ def encode(self, text, return_tensors=None, is_global=False, device="cuda", **kw
self.build_inputs_with_special_tokens(token_ids) for token_ids in token_ids_list
]
token_ids_list = self.convert_to_tensors(
-    token_ids_list, return_tensors=return_tensors, is_global=is_global, **kwargs
+    token_ids_list,
+    return_tensors=return_tensors,
+    is_global=is_global,
+    device=device,
+    **kwargs,
)
return token_ids_list
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
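The fix above threads `encode`'s `device` argument through to `convert_to_tensors`; previously the argument was accepted but never forwarded. A minimal usage sketch, assuming an already-built LiBai tokenizer (the `return_tensors` value is an assumption based on LiBai's OneFlow backend):

```python
# Hypothetical sketch: after this fix the requested device reaches the
# returned tensors instead of convert_to_tensors' own default.
ids = tokenizer.encode(
    "给出3点关于保持身体健康的意见。",
    return_tensors="of",  # OneFlow tensors (value assumed)
    device="npu",         # now forwarded to convert_to_tensors
)
```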
67 changes: 67 additions & 0 deletions projects/Qwen/README.md
@@ -0,0 +1,67 @@

### Inference

- cuda: PASS

```bash
python projects/Qwen/pipeline.py --model_path=/root/models/Qwen1.5-7B-Chat --mode=huggingface
```

- npu: PASS

```bash
python projects/Qwen/pipeline.py --model_path=/data0/hf_models/qwen2/Qwen1.5-7B-Chat --mode=huggingface --device=npu
```

- xpu: PASS

```bash
python projects/Qwen/pipeline.py --model_path=/root/models/Qwen1.5-7B-Chat --mode=huggingface --device=xpu
```
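
The same inference path can also be driven from Python; below is a sketch mirroring `main()` in `projects/Qwen/pipeline.py` from this commit (the model path is a placeholder):

```python
from libai.utils import distributed as dist
from projects.Qwen.pipeline import TextGenerationPipeline

pipeline = TextGenerationPipeline(
    "projects/Qwen/configs/qwen_config.py",
    data_parallel=1,
    tensor_parallel=1,
    pipeline_parallel=1,
    pipeline_num_layers=32,
    model_path="/root/models/Qwen1.5-7B-Chat",  # placeholder path
    mode="huggingface",
    device="cuda",  # or "xpu" / "npu", matching the commands above
)

output = pipeline(inputs=["给出3点关于保持身体健康的意见。"])
if dist.is_main_process():
    print(output)
```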

### Training

- data preparation

```bash
python projects/Qwen/utils/data_prepare.py
```

- cuda: PASS

```bash
export NUM_GPUS=8
python3 -m oneflow.distributed.launch \
--nproc_per_node ${NUM_GPUS} \
--nnodes 1 \
--node_rank 0 \
--master_addr 127.0.0.1 \
--master_port 12345 \
tools/train_net.py --config-file=projects/Qwen/configs/qwen_sft.py \
graph.enabled=True \
train.input_placement_device="cuda" \
train.dist.device_type="cuda" \
train.dist.pipeline_parallel_size=${NUM_GPUS}
```
Note: OOM on 4 x A100-PCIE-40GB.

- xpu: OOM

```bash
export NUM_GPUS=1
python3 -m oneflow.distributed.launch \
--nproc_per_node ${NUM_GPUS} \
--nnodes 1 \
--node_rank 0 \
--master_addr 127.0.0.1 \
--master_port 12345 \
tools/train_net.py --config-file=projects/Qwen/configs/qwen_sft.py \
graph.enabled=False \
train.input_placement_device="xpu" \
train.dist.device_type="xpu" \
train.dist.pipeline_parallel_size=${NUM_GPUS}
```

- npu: not tested; likely does not work yet
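
The trailing `key=value` pairs in the launch commands above are lazy-config overrides applied on top of `projects/Qwen/configs/qwen_sft.py` (shown below). A hypothetical sketch of applying the same overrides programmatically, assuming LiBai mirrors detectron2's `LazyConfig` API:

```python
# A sketch, not verified against LiBai's public API: load the SFT config,
# apply the same overrides as the xpu command above, and save the result.
from libai.config import LazyConfig

cfg = LazyConfig.load("projects/Qwen/configs/qwen_sft.py")
cfg.graph.enabled = False
cfg.train.input_placement_device = "xpu"
cfg.train.dist.device_type = "xpu"
cfg.train.dist.pipeline_parallel_size = 1
LazyConfig.save(cfg, "qwen_sft_xpu.yaml")
```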


62 changes: 62 additions & 0 deletions projects/Qwen/configs/qwen_config.py
@@ -0,0 +1,62 @@
from omegaconf import DictConfig, OmegaConf

from configs.common.train import train
from libai.config import LazyCall
from projects.Qwen.qwen2 import Qwen2ForCausalLM
from projects.Qwen.tokenizer import Qwen2Tokenizer

cfg = dict(
# Model
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-06,
rope_theta=10000.0,
attention_dropout=0.0,
tie_word_embeddings=False,
use_scaled_init_for_output_weights=False,
scale_mask_softmax_fusion=False,
amp_enabled=True,
# Inference
is_encoder_decoder=False,
max_length=256,
min_length=0,
do_sample=False,
early_stopping=False,
num_beams=1,
num_beam_groups=1,
diversity_penalty=0.0,
temperature=0.7,
top_k=20,
top_p=0.8,
typical_p=1.0,
repetition_penalty=1.05,
length_penalty=1.0,
no_repeat_ngram_size=0,
encoder_no_repeat_ngram_size=0,
num_return_sequences=1,
chunk_size_feed_forward=0,
output_scores=False,
use_cache=True,
bos_token_id=151643,
eos_token_id=151645,
pad_token_id=151643,
# train
pretrained_model_path="/root/models/Qwen1.5-7B-Chat",
)

cfg = DictConfig(cfg)

model = LazyCall(Qwen2ForCausalLM)(cfg=cfg)
tokenization = OmegaConf.create()
tokenization.make_vocab_size_divisible_by = 1
tokenization.tokenizer = LazyCall(Qwen2Tokenizer)(
# vocab_file="/root/models/Qwen1.5-7B/vocab.json",
# merges_file="/root/models/Qwen/Qwen1.5-7B/merges.txt",
)
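
For reference, a minimal sketch of materializing this config into a model, assuming LiBai exposes a detectron2-style `instantiate` for `LazyCall` objects:

```python
# A sketch under the assumption that `instantiate` is importable from
# libai.config; it recursively builds the LazyCall tree.
from libai.config import instantiate
from projects.Qwen.configs.qwen_config import cfg, model as lazy_model

qwen = instantiate(lazy_model)  # constructs Qwen2ForCausalLM(cfg=cfg)
print(cfg.hidden_size, cfg.max_position_embeddings)  # 4096 32768
```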
97 changes: 97 additions & 0 deletions projects/Qwen/configs/qwen_sft.py
@@ -0,0 +1,97 @@
import os

from omegaconf import OmegaConf

from configs.common.models.graph import graph
from configs.common.optim import optim
from configs.common.train import train
from libai.config import LazyCall
from libai.data.build import build_nlp_test_loader, build_nlp_train_loader
from libai.evaluation import PPLEvaluator
from libai.scheduler import WarmupExponentialLR
from projects.Qwen.configs.qwen_config import cfg
from projects.Qwen.qwen2 import Qwen2ForCausalLM
from projects.Qwen.tokenizer import Qwen2Tokenizer
from projects.Qwen.qwen_dataset import QwenDataset

# Hyperparameters
weight_decay = 0.1
learning_rate = 5e-5
dataset_path = "./alpaca_data"
pretrained_model_path = "/root/models/Qwen1.5-7B-Chat"

# graph & optim
graph["enabled"] = False
optim.update(
dict(
lr=learning_rate,
weight_decay=weight_decay,
)
)

# tokenize
tokenization = OmegaConf.create()
tokenization.make_vocab_size_divisible_by = 1
tokenization.tokenizer = LazyCall(Qwen2Tokenizer)(
vocab_file=pretrained_model_path + "/vocab.json",
merges_file=pretrained_model_path + "/merges.txt",
)


# model
cfg.pretrained_model_path = pretrained_model_path
model = LazyCall(Qwen2ForCausalLM)(cfg=cfg)

# datasets
dataloader = OmegaConf.create()
dataloader.train = LazyCall(build_nlp_train_loader)(
dataset=[
LazyCall(QwenDataset)(
path=os.path.join(dataset_path, "train"), tokenizer=tokenization.tokenizer
)
],
)
dataloader.test = [
LazyCall(build_nlp_test_loader)(
dataset=LazyCall(QwenDataset)(
path=os.path.join(dataset_path, "test"), tokenizer=tokenization.tokenizer
),
),
]

train.update(
dict(
output_dir="./sft_result",
train_micro_batch_size=1,
test_micro_batch_size=1,
train_epoch=1,
train_iter=1,
log_period=1,
warmup_ratio=1 / 3,
num_accumulation_steps=1,
rdma_enabled=False,
amp=dict(enabled=True),
activation_checkpoint=dict(enabled=True),
checkpointer=dict(
period=5000,
max_to_keep=20,
),
dist=dict(
data_parallel_size=1,
tensor_parallel_size=1,
pipeline_parallel_size=8,
pipeline_num_layers=cfg.hidden_layers,
),
evaluation=dict(
enabled=False,
evaluator=LazyCall(PPLEvaluator)(),
eval_period=1000,
eval_iter=1e5,
),
scheduler=LazyCall(WarmupExponentialLR)(
warmup_factor=0.0,
gamma=1.0,
warmup_method="linear",
),
)
)
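
Note: with `pipeline_parallel_size=8` and `pipeline_num_layers=cfg.hidden_layers` (32 here), each pipeline stage presumably hosts 32 / 8 = 4 transformer layers, so the README's `train.dist.pipeline_parallel_size=${NUM_GPUS}` override should keep `NUM_GPUS` a divisor of 32.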
150 changes: 150 additions & 0 deletions projects/Qwen/pipeline.py
@@ -0,0 +1,150 @@
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import click

from libai.config import try_get_key
from libai.engine import DefaultTrainer
from libai.inference.basic import BasePipeline
from libai.utils import distributed as dist


class TextGenerationPipeline(BasePipeline):
def load_pretrain_weight(self, libai_cfg_model, model_path, mode="huggingface"):
"""load pretrained model.
Args:
libai_cfg_model (libai.models): Lazy config Model in Libai, you can import it
by `from libai.config.configs.common.models.bert
import pretrain_model as libai_cfg_model`
model_path (str): The directory path of pretrained model,
"""
if mode == "huggingface":
from projects.Qwen.utils.qwen2_loader import Qwen2LoaderHuggerFace

model_loader = Qwen2LoaderHuggerFace(
libai_cfg_model,
libai_cfg_model.cfg,
model_path,
)
model = model_loader.load()
model.eval()
return model

elif mode == "libai":
from projects.Qwen.utils.qwen2_loader import Qwen2LoaderLiBai

model_loader = Qwen2LoaderLiBai(
libai_cfg_model,
libai_cfg_model.cfg,
model_path,
)
model = model_loader.load()
model.eval()
return model

elif mode == "random":
from libai.engine import DefaultTrainer

return DefaultTrainer.build_model(self.cfg)
else:
raise NotImplementedError

def _parse_parameters(self, **pipeline_parameters):
preprocess_params = {}
forward_params = {**pipeline_parameters}
postprocess_params = {}

return preprocess_params, forward_params, postprocess_params

def preprocess(self, inputs, **kwargs) -> dict:
# tokenizer encode
import oneflow as flow

inputs = flow.tensor(self.tokenizer.encode(inputs, add_bos=True, padding=True))

inputs = {
"input_ids": inputs,
}

return inputs

def forward(self, inputs, **kwargs) -> dict:
inputs = dist.convert_to_distributed_default_setting(inputs["input_ids"])
outputs = self.model.generate(inputs, max_length=50, **kwargs)
return {"return_ids": outputs}

def postprocess(self, model_output_dict, **kwargs) -> dict:
return_ids = model_output_dict["return_ids"]
records = [
{"generated_text": self.tokenizer.decode(return_ids[i])}
for i in range(return_ids.size(0))
]
return records

def build_tokenizer(self, cfg):
tokenizer = None
if try_get_key(cfg, "tokenization") is not None:
tokenizer_cfg = cfg.tokenization.tokenizer
if "vocab_file" not in tokenizer_cfg:
# If "vocab_file" does not exist in the tokenizer's config,
# set it to default as f"{model_path}/vocab.json"
tokenizer_cfg.vocab_file = str(Path(self.model_path).joinpath("vocab.json"))
if "merges_file" not in tokenizer_cfg:
# If "merges_file" does not exist in the tokenizer's config,
# set it to default as f"{model_path}/merges.txt"
tokenizer_cfg.merges_file = str(Path(self.model_path).joinpath("merges.txt"))
tokenizer = DefaultTrainer.build_tokenizer(cfg)
return tokenizer


@click.command()
@click.option(
"--config_file",
default="projects/Qwen/configs/qwen_config.py",
help="Path to the configuration file.",
)
@click.option("--model_path", default=None, help="Path to the model checkpoint.")
@click.option(
"--mode",
default="libai",
help="Mode for the dataloader pipeline, e.g., 'libai' or 'huggingface'.",
)
@click.option(
"--device", default="cuda", help="Device to run the model on, e.g., 'cuda', 'xpu', 'npu'."
)
def main(config_file, model_path, mode, device):
pipeline = TextGenerationPipeline(
config_file,
data_parallel=1,
tensor_parallel=1,
pipeline_parallel=1,
pipeline_num_layers=32,
model_path=model_path,
mode=mode,
device=device,
)

text = ["给出3点关于保持身体健康的意见。"]

output = pipeline(inputs=text)
if dist.is_main_process():
print(output)


if __name__ == "__main__":
main()