Commit
Chatglm devices (#550)
* fix bfloat16

* update tokenizer encode

* support python 3.10-

* support python 3.10-

* support python 3.10-

* click

* fix tokenizer path

* compatible with 3.8

* update README

* format

* isort

* rm

* modify workflow
ShawnXuan authored Sep 10, 2024
1 parent 1efccd8 commit 7aeb4ca
Showing 9 changed files with 85 additions and 83 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/py.yml
@@ -4,7 +4,7 @@ env:
ONEFLOW_SRC: oneflow-src
on:
pull_request:
types: [review_requested]
types: [opened, review_requested, ready_for_review, synchronize, unlocked]
branches:
- "*"
workflow_dispatch:
5 changes: 5 additions & 0 deletions libai/models/utils/model_loader/base_loader.py
@@ -384,6 +384,11 @@ def _convert_tensor(self, tensor):
Returns:
flow.Tensor: The target tensor.
"""
import torch

if tensor.dtype == torch.bfloat16:
data = tensor.detach().half().cpu().numpy()
return flow.Tensor(data)
return flow.Tensor(tensor.detach().cpu().numpy())

def _convert_tensors(self, torch_state_dict):
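The branch added above handles `torch.bfloat16` checkpoints: NumPy has no bfloat16 dtype, so such a tensor cannot go through `.numpy()` directly and is first cast to float16. A minimal standalone sketch of that path, assuming both `torch` and `oneflow` are installed (the example tensor is illustrative):

```python
# Standalone sketch of the bfloat16 conversion path added above; values are illustrative.
import torch
import oneflow as flow

src = torch.randn(2, 3, dtype=torch.bfloat16)

# NumPy has no bfloat16 dtype, so src.numpy() would fail; casting to float16 first
# yields an ndarray that flow.Tensor can consume.
converted = flow.Tensor(src.detach().half().cpu().numpy())
print(tuple(converted.shape))  # (2, 3)
```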
8 changes: 6 additions & 2 deletions libai/tokenizer/tokenization_base.py
@@ -805,14 +805,18 @@ def _convert_token_to_id_with_added_voc(self, token):
def _convert_token_to_id(self, token):
raise NotImplementedError

def encode(self, text, return_tensors=None, is_global=False, **kwargs):
def encode(self, text, return_tensors=None, is_global=False, device="cuda", **kwargs):
if isinstance(text, str):
tokens = self.tokenize(text)
token_ids = self.convert_tokens_to_ids(tokens)
if hasattr(self, "build_inputs_with_special_tokens"):
token_ids = self.build_inputs_with_special_tokens(token_ids)
token_ids = self.convert_to_tensors(
token_ids, return_tensors=return_tensors, is_global=is_global, **kwargs
token_ids,
return_tensors=return_tensors,
is_global=is_global,
device=device,
**kwargs,
)
return token_ids
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
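The new `device` argument is forwarded to `convert_to_tensors`, so callers can place the encoded (global) tensor on a specific accelerator at encode time rather than moving it afterwards. A minimal usage sketch; the import path and vocab file below are illustrative assumptions, not taken from this diff:

```python
# Usage sketch for the extended encode() signature; import path and vocab_file are assumptions.
from projects.ChatGLM.tokenizer import ChatGLMTokenizer

tokenizer = ChatGLMTokenizer(vocab_file="/root/models/chatglm2-6b/tokenizer.model")

# device is threaded through to convert_to_tensors; "cuda" remains the default,
# and "npu"/"xpu" are accepted the same way.
input_ids = tokenizer.encode(
    "what is beam search?",
    return_tensors="of",
    is_global=True,
    device="npu",
)
```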
9 changes: 9 additions & 0 deletions projects/ChatGLM/README.md
@@ -47,3 +47,12 @@ python projects/ChatGLM/pipeline.py

### ChatGLM Lora Inference
- In `projects/ChatGLM/configs/chatglm_config.py`, set `lora_enable=True`; the remaining steps are the same as without LoRA.

### npu/xpu/cuda example
```python
python projects/ChatGLM/pipeline.py --model_path=/data0/hf_models/chatglm/chatglm2-6b --mode=huggingface --device=npu

python projects/ChatGLM/pipeline.py --model_path=/root/models/chatglm2-6b/ --mode=huggingface --device=xpu

python projects/ChatGLM/pipeline.py --model_path=/root/models/chatglm2-6b/ --mode=huggingface --device=cuda
```
6 changes: 2 additions & 4 deletions projects/ChatGLM/configs/chatglm_config.py
@@ -61,7 +61,7 @@
output_scores=False,
output_hidden_states=False,
# train
pretrained_model_path=os.environ["CHATGLM_HF_DIR"],
pretrained_model_path="chatglm/chatglm2-6b",
# lora_cfg
lora_enable=False,
lora_cfg=dict(
@@ -86,6 +86,4 @@
model = LazyCall(ChatGLMForConditionalGeneration)(cfg=cfg)
tokenization = OmegaConf.create()
tokenization.make_vocab_size_divisible_by = 1
tokenization.tokenizer = LazyCall(ChatGLMTokenizer)(
vocab_file=f"{os.environ['CHATGLM_HF_DIR']}/tokenizer.model"
)
tokenization.tokenizer = LazyCall(ChatGLMTokenizer)()
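With the `CHATGLM_HF_DIR` environment-variable lookups removed, the vocab file is now resolved at pipeline build time from `--model_path` (see `build_tokenizer` in `projects/ChatGLM/pipeline.py` below). A tokenizer file outside the model directory can still be pinned explicitly in the config, reusing the config's existing imports; a minimal sketch with an illustrative path:

```python
# Optional explicit override of the model_path-based default; the path is illustrative.
tokenization.tokenizer = LazyCall(ChatGLMTokenizer)(
    vocab_file="/data0/hf_models/chatglm/chatglm2-6b/tokenizer.model"
)
```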
16 changes: 8 additions & 8 deletions projects/ChatGLM/lora/layers.py
@@ -18,7 +18,7 @@
import math
import warnings
from abc import ABC
from typing import Any, List, Optional, Union
from typing import Any, List, Optional, Tuple, Union

import oneflow as flow
import oneflow.nn as nn
@@ -41,18 +41,18 @@ class BaseTunerLayer(ABC):
active_adapter = None

# All names of layers that may contain adapter (trainable) weights
adapter_layer_names: tuple[str] = ()
adapter_layer_names: Tuple[str, ...] = ()
# All names of other parameters that may contain adapter-related parameters
other_param_names: tuple[str] = ()
other_param_names: Tuple[str, ...] = ()

# indicates whether all adapters should be disabled
_disable_adapters: bool = False

# the currently active adapter(s)
_active_adapter: str | list[str] = "default"
_active_adapter: Union[str, List[str]] = "default"

# List all merged adapters
merged_adapters: list[str] = []
merged_adapters: List[str] = []

def get_base_layer(self) -> nn.Module:
"""
@@ -72,7 +72,7 @@ def weight(self) -> flow.Tensor:
weight = base_layer.weight
return weight

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
raise NotImplementedError

def unmerge(self) -> None:
@@ -119,7 +119,7 @@ def enable_adapters(self, enabled: bool) -> None:
layer.requires_grad_(False)
self._disable_adapters = True

def set_adapter(self, adapter_names: str | list[str]) -> None:
def set_adapter(self, adapter_names: Union[str, List[str]]) -> None:
"""Set the active adapter(s).
Args:
@@ -142,7 +142,7 @@ def set_adapter(self, adapter_names: str | list[str]) -> None:

self._active_adapter = adapter_names

def _all_available_adapter_names(self) -> list[str]:
def _all_available_adapter_names(self) -> List[str]:
"""Return a sorted list of all available adapter names"""
adapter_names = set()
for name in self.adapter_layer_names + self.other_param_names:
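These annotation changes are what back the `support python 3.10-` commits: built-in generics such as `tuple[str]` only subscript at runtime on Python 3.9+ (PEP 585), and the `X | Y` union syntax needs 3.10+ (PEP 604), whereas the `typing` module spellings evaluate fine on 3.8. A minimal contrast sketch; the class and attribute names are illustrative:

```python
# 3.8-compatible annotations vs. the newer syntax; names are illustrative.
from typing import List, Optional, Tuple, Union


class AdapterLayerSketch:
    # Python 3.9+/3.10+ only:
    #   adapter_layer_names: tuple[str, ...] = ()
    #   _active_adapter: str | list[str] = "default"
    adapter_layer_names: Tuple[str, ...] = ()            # evaluates on 3.8
    _active_adapter: Union[str, List[str]] = "default"   # evaluates on 3.8

    def merge(self, adapter_names: Optional[List[str]] = None) -> None:
        ...
```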
8 changes: 4 additions & 4 deletions projects/ChatGLM/lora/lora_model.py
@@ -22,7 +22,7 @@
from dataclasses import asdict
from enum import Enum
from itertools import chain
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from oneflow import nn
from tqdm import tqdm
@@ -50,7 +50,7 @@ def __init__(self, model, peft_config, adapter_name: str) -> None:
self.inject_adapter(self.model, adapter_name)

@property
def active_adapters(self) -> list[str]:
def active_adapters(self) -> List[str]:
if isinstance(self.active_adapter, str):
return [self.active_adapter]
# is already a list of str
@@ -192,7 +192,7 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):
if adapter_name in n:
p.requires_grad = False

def merge_adapter(self, safe_merge=False, adapter_names: Optional[list[str]] = None) -> None:
def merge_adapter(self, safe_merge=False, adapter_names: Optional[List[str]] = None) -> None:
"""
This method merges the adapter layers into the base model.
@@ -404,7 +404,7 @@ def disable_adapter_layers(self) -> None:
warnings.warn(msg)
self._set_adapter_layers(enabled=False)

def set_adapter(self, adapter_name: str | list[str]) -> None:
def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
"""Set the active adapter(s).
Args:
4 changes: 2 additions & 2 deletions projects/ChatGLM/lora/utils.py
@@ -15,14 +15,14 @@
# limitations under the License.

import re
from typing import List
from typing import List, Optional, Union

import oneflow as flow

COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks", "layer"]


def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:
def check_target_module_exists(config, key: str) -> Union[bool, Optional[re.Match]]:
"""A helper method to check if the passed module's key name matches
any of the target modules in the adapter_config.
110 changes: 48 additions & 62 deletions projects/ChatGLM/pipeline.py
@@ -12,8 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
from typing import Union

import click

from libai.config import try_get_key
from libai.engine import DefaultTrainer
from libai.inference.basic import BasePipeline
from libai.utils import distributed as dist

@@ -81,7 +86,7 @@ def load_pretrain_weight(self, libai_cfg_model, model_path, mode="huggingface"):
return model

elif mode == "random":
from libai.engine import DefaultTrainer
# from libai.engine import DefaultTrainer

return DefaultTrainer.build_model(self.cfg)
else:
@@ -94,19 +99,32 @@ def _parse_parameters(self, **pipeline_parameters):

return preprocess_params, forward_params, postprocess_params

def preprocess(self, sentence: str | list, **kwargs) -> dict:
def preprocess(self, sentence: Union[str, list], **kwargs) -> dict:
#
if type(sentence) is str:
inputs = {
"inputs": sentence,
}
else:
inputs = self.tokenizer.encode(sentence, return_tensors="of", is_global=True)
inputs = self.tokenizer.encode(
sentence, return_tensors="of", is_global=True, device=self.device
)
inputs = {
"input_ids": inputs,
}
return inputs

def build_tokenizer(self, cfg):
tokenizer = None
if try_get_key(cfg, "tokenization") is not None:
tokenizer_cfg = cfg.tokenization.tokenizer
if "vocab_file" not in tokenizer_cfg:
# If "vocab_file" does not exist in the tokenizer's config,
# set it to default as f"{model_path}/tokenizer.model"
tokenizer_cfg.vocab_file = str(Path(self.model_path).joinpath("tokenizer.model"))
tokenizer = DefaultTrainer.build_tokenizer(cfg)
return tokenizer

def forward(self, inputs, **kwargs) -> dict:
if "input_ids" not in inputs:
if "history" in kwargs:
@@ -143,85 +161,53 @@ def reset_conversation(self):
self.history = []


if __name__ == "__main__":
# ----- load huggingface checkpoint -----
@click.command()
@click.option(
"--config_file",
default="projects/ChatGLM/configs/chatglm_config.py",
help="Path to the configuration file.",
)
@click.option("--model_path", default=None, help="Path to the model checkpoint.")
@click.option(
"--mode",
default="libai",
help="Mode for the dataloader pipeline, e.g., 'libai' or 'huggingface'.",
)
@click.option(
"--device", default="cuda", help="Device to run the model on, e.g., 'cuda', 'xpu', 'npu'."
)
def main(config_file, model_path, mode, device):
text = "浏览器输入www.baidu.com 并且显示网页,从计算机网络的角度说明实现的全过程"
text2 = (
"5600分为A、B、C三部分,如果A比C的比例是1/7:1/7:1/14,那么A比C多多少?\n"
"选项:\n(A) 300\n(B) 992 \n(C) 1120\n(D) 552\n(E) 312 让我们先想想。一些随机推理:"
)
texts = [
text,
text2,
"a dog is flying on the sky",
"Wikipedia is a free online",
"what is beam search?",
"what is beam search?",
]
pipeline = TextGenerationPipeline(
"projects/ChatGLM/configs/chatglm_config.py",
config_file,
data_parallel=1,
tensor_parallel=1,
pipeline_parallel=1,
pipeline_num_layers=28,
model_path=os.environ["CHATGLM_HF_DIR"],
mode="huggingface",
model_path=model_path,
mode=mode,
device=device,
)
pipeline.model = pipeline.model.half()

if isinstance(texts, list):
output = pipeline(inputs=texts, do_sample=False, max_length=50)
output = pipeline(inputs=texts, do_sample=False, max_length=400)
if dist.is_main_process():
for text, record in zip(texts, output):
print(f"Q:{text}||A:{record}")

# if isinstance(text, str):
# output = pipeline(inputs=text, do_sample=False, max_length=400)
# if dist.is_main_process():
# for record in output:
# print(record["generated_text"])
# pipeline.reset_conversation()
# output = pipeline(inputs=text2, do_sample=False, max_length=400)
# if dist.is_main_process():
# for record in output:
# print(record["generated_text"])

# # ----- load libai checkpoint -----
# pipeline = TextGenerationPipeline(
# "projects/ChatGLM/configs/chatglm_config.py",
# data_parallel=1,
# tensor_parallel=1,
# pipeline_parallel=1,
# pipeline_num_layers=28,
# model_path="/home/lixin/codes/libai/lora_sft_result/model_final/model",
# mode="libai",
# )
# pipeline.model = pipeline.model.half()

# if isinstance(texts, list):
# output = pipeline(inputs=texts, do_sample=False, max_length=50)
# if dist.is_main_process():
# for text, record in zip(texts, output):
# print(f"Q:{text}||A:{record}")

# if isinstance(text, str):
# output = pipeline(inputs=text, do_sample=False, max_length=400)
# if dist.is_main_process():
# for record in output:
# print(record['generated_text'])
# pipeline.reset_conversation()
# output = pipeline(inputs=text2, do_sample=False, max_length=400)
# if dist.is_main_process():
# for record in output:
# print(record['generated_text'])

# ----- pure huggingface predict -----
# from transformers import AutoModel, AutoTokenizer

# tokenizer = AutoTokenizer.from_pretrained(glm_model_path, trust_remote_code=True)
# model = AutoModel.from_pretrained(glm_model_path, trust_remote_code=True).half().cuda()
# model = model.eval()
# history = []
# for _ in range(1):
# response, history = model.chat(
# tokenizer, text, history=history, do_sample=False, max_length=400
# )
# print(response)

if __name__ == "__main__":
main()

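For reference, the click options above map one-to-one onto the `TextGenerationPipeline` constructor, so the CLI calls shown in the README are equivalent to building the pipeline in Python. A minimal programmatic sketch, assuming the repository root is importable and using an illustrative local checkpoint path:

```python
# Programmatic equivalent of the CLI entry point; model_path and device are illustrative.
from libai.utils import distributed as dist
from projects.ChatGLM.pipeline import TextGenerationPipeline

pipeline = TextGenerationPipeline(
    "projects/ChatGLM/configs/chatglm_config.py",
    data_parallel=1,
    tensor_parallel=1,
    pipeline_parallel=1,
    pipeline_num_layers=28,
    model_path="/root/models/chatglm2-6b/",
    mode="huggingface",
    device="npu",
)
pipeline.model = pipeline.model.half()

output = pipeline(inputs=["what is beam search?"], do_sample=False, max_length=400)
if dist.is_main_process():
    print(output)
```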