From 1efccd8a127954f4a0c2804c073f335a167ad711 Mon Sep 17 00:00:00 2001
From: XIE Xuan
Date: Thu, 5 Sep 2024 10:25:38 +0800
Subject: [PATCH] Llama device (#548)

* update llama for multi devices
* xpu and npu config files
* update device for inference
* update
* update
* update README
* update
* format
* format
* fix
* feat: support third-party oneflow device extension (#549)

* feat: support third-party oneflow device extensions

also, refactor the build process of model and tokenizer using the
pretrained_model_path config

* refactor: remove unnecessary config and warnings
* docs: update readme for commands to run llama on npu and xpu
* fix import order
* update
* update
* fix: skip lint on oneflow third-party imports

---------

Co-authored-by: Qunhong Zeng <871206929@qq.com>
---
 libai/inference/basic.py                | 24 ++++++++++++-
 libai/utils/distributed.py              | 12 +++++++
 projects/Llama/{readme.md => README.md} | 15 +++++++-
 projects/Llama/configs/llama_config.py  |  2 +-
 projects/Llama/pipeline.py              | 48 ++++++++++++++-----------
 projects/Llama/tokenizer.py             |  4 +--
 projects/Llama/utils/llama_loader.py    |  4 +++
 7 files changed, 83 insertions(+), 26 deletions(-)
 rename projects/Llama/{readme.md => README.md} (84%)

diff --git a/libai/inference/basic.py b/libai/inference/basic.py
index 94d3f1781..b869e56cc 100644
--- a/libai/inference/basic.py
+++ b/libai/inference/basic.py
@@ -15,6 +15,7 @@
 
 import logging
 from abc import ABCMeta, abstractmethod
+from pathlib import Path
 from typing import Any, Dict
 
 import oneflow as flow
@@ -43,6 +44,7 @@ def __init__(
         pipeline_num_layers=None,
         model_path=None,
         mode="libai",
+        device="cuda",
         **kwargs,
     ):
         # init cfg
@@ -60,10 +62,21 @@ def __init__(
             pipeline_stage_id,
             pipeline_num_layers,
         )
+        self.device = device
+        self.cfg.train.dist.device_type = device
         dist.setup_dist_util(self.cfg.train.dist)
         logger.info(self.cfg.train.dist)
 
         # initialize and load model
+        self.model_path = model_path
+        if self.model_path is not None:
+            # If a model_path is provided to BasePipeline, it takes priority
+            # and overwrites the pretrained_model_path in the config
+            self.cfg.model.cfg.pretrained_model_path = self.model_path
+        else:
+            # If the model_path in BasePipeline is None, use the one from the config
+            assert "pretrained_model_path" in self.cfg.model.cfg
+            self.model_path = self.cfg.model.cfg.pretrained_model_path
         self.model = self.load_pretrain_weight(self.cfg.model, model_path, mode=mode)
         self.model._apply(dist.convert_to_distributed_default_setting)
 
@@ -134,6 +147,13 @@ def load_pretrain_weight(
 
     def build_tokenizer(self, cfg):
         tokenizer = None
        if try_get_key(cfg, "tokenization") is not None:
+            tokenizer_cfg = cfg.tokenization.tokenizer
+            if "pretrained_model_path" not in tokenizer_cfg:
+                # If "pretrained_model_path" is not set in the tokenizer's config,
+                # default it to f"{model_path}/tokenizer.model"
+                tokenizer_cfg.pretrained_model_path = str(
+                    Path(self.model_path).joinpath("tokenizer.model")
+                )
             tokenizer = DefaultTrainer.build_tokenizer(cfg)
         return tokenizer
 
@@ -167,7 +187,9 @@ def to_local(self, model_outputs_dict):
         for key, value in model_outputs_dict.items():
             if isinstance(value, flow.Tensor) and value.is_global:
                 model_outputs_dict[key] = dist.ttol(
-                    value, ranks=[0] if value.placement.ranks.ndim == 1 else [[0]]
+                    value,
+                    device=self.device,
+                    ranks=[0] if value.placement.ranks.ndim == 1 else [[0]],
                 )
         if flow.cuda.is_available():
             dist.synchronize()
diff --git a/libai/utils/distributed.py b/libai/utils/distributed.py
index f84313fd7..f64479210 100644
--- a/libai/utils/distributed.py
+++ b/libai/utils/distributed.py
@@ -72,6 +72,18 @@ def _init_distributed_env(self, cfg):
 
         # Add set device type
         self._device_type = try_get_key(cfg, "device_type", default="cuda")
+        if self._device_type == "npu":
+            try:
+                import oneflow_npu  # noqa: F401
+            except ImportError:
+                raise ImportError("'oneflow_npu' is missing. Install it to use NPU devices.")
+        elif self._device_type == "xpu":
+            try:
+                import oneflow_xpu  # noqa: F401
+            except ImportError:
+                raise ImportError("'oneflow_xpu' is missing. Install it to use XPU devices.")
+        elif self._device_type not in ("cuda", "npu", "xpu", "cpu"):
+            raise NotImplementedError(f"Unsupported device {self._device_type}")
 
     def _init_parallel_size(self, cfg):
 
diff --git a/projects/Llama/readme.md b/projects/Llama/README.md
similarity index 84%
rename from projects/Llama/readme.md
rename to projects/Llama/README.md
index 9adb3d925..f58e416c1 100644
--- a/projects/Llama/readme.md
+++ b/projects/Llama/README.md
@@ -44,4 +44,17 @@ python projects/Llama/utils/eval_adapter.py
 - Adjust the parameters in the `projects/Llama/pipeline.py`, and run:
 ```bash
 bash tools/infer.sh projects/Llama/pipeline.py 8
-```
\ No newline at end of file
+```
+
+## npu/xpu example
+
+- npu
+```bash
+python projects/Llama/pipeline.py --device=npu --mode=huggingface --model_path /your/model/path
+```
+
+- xpu
+```bash
+python projects/Llama/pipeline.py --device=xpu --mode=huggingface --model_path /your/model/path
+```
+
diff --git a/projects/Llama/configs/llama_config.py b/projects/Llama/configs/llama_config.py
index 01d208016..36f95d126 100644
--- a/projects/Llama/configs/llama_config.py
+++ b/projects/Llama/configs/llama_config.py
@@ -57,5 +57,5 @@
 tokenization = OmegaConf.create()
 tokenization.make_vocab_size_divisible_by = 1
 tokenization.tokenizer = LazyCall(LlamaTokenizer)(
-    pretrained_model_path="meta-llama/Llama-2-7b-hf/tokenizer.model"
+    # pretrained_model_path="meta-llama/Llama-2-7b-hf/tokenizer.model"
 )
diff --git a/projects/Llama/pipeline.py b/projects/Llama/pipeline.py
index bea4a2f56..4b65d2895 100644
--- a/projects/Llama/pipeline.py
+++ b/projects/Llama/pipeline.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import click
+
 from libai.inference.basic import BasePipeline
 from libai.utils import distributed as dist
 
@@ -67,7 +69,7 @@ def _parse_parameters(self, **pipeline_parameters):
     def preprocess(self, inputs, **kwargs) -> dict:
         # tokenizer encode
-        inputs = self.tokenizer.tokenize(inputs, add_bos=True, padding=True)
+        inputs = self.tokenizer.tokenize(inputs, add_bos=True, padding=True, device=self.device)
         inputs = {
             "input_ids": inputs,
         }
@@ -87,31 +89,31 @@ def postprocess(self, model_output_dict, **kwargs) -> dict:
         return records
 
 
-if __name__ == "__main__":
-    # ----- load huggingface checkpoint -----
-    # pipeline = TextGenerationPipeline(
-    #     "projects/Llama/configs/llama_config.py",
-    #     data_parallel=1,
-    #     tensor_parallel=1,
-    #     pipeline_parallel=1,
-    #     pipeline_num_layers=32,
-    #     model_path="",
-    #     mode="huggingface",
-    # )
-
-    # output = pipeline(inputs=text)
-    # if dist.is_main_process():
-    #     print(output)
-
-    # ----- load libai checkpoint -----
+@click.command()
+@click.option(
+    "--config_file",
+    default="projects/Llama/configs/llama_config.py",
+    help="Path to the configuration file.",
+)
+@click.option("--model_path", default=None, help="Path to the model checkpoint.")
+@click.option(
+    "--mode",
+    default="libai",
+    help="Checkpoint loading mode, e.g., 'libai' or 'huggingface'.",
+)
+@click.option(
+    "--device", default="cuda", help="Device to run the model on, e.g., 'cuda', 'xpu', 'npu'."
+)
+def main(config_file, model_path, mode, device):
     pipeline = TextGenerationPipeline(
-        "projects/Llama/configs/llama_config.py",
+        config_file,
         data_parallel=1,
         tensor_parallel=1,
         pipeline_parallel=1,
         pipeline_num_layers=32,
-        model_path="",
-        mode="libai",
+        model_path=model_path,
+        mode=mode,
+        device=device,
     )
 
     text = [
@@ -120,3 +122,7 @@ def postprocess(self, model_output_dict, **kwargs) -> dict:
     output = pipeline(inputs=text)
     if dist.is_main_process():
         print(output)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/projects/Llama/tokenizer.py b/projects/Llama/tokenizer.py
index 56aca8336..1598a1dbe 100644
--- a/projects/Llama/tokenizer.py
+++ b/projects/Llama/tokenizer.py
@@ -75,9 +75,9 @@ def tokenize(
         if add_eos:
             tokens = [token + [self.eos_token_id] for token in tokens]
 
-        if device == "cuda":
+        if device:
             sbp = kwargs.get("sbp", dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]))
-            placement = kwargs.get("placement", flow.placement("cuda", [0]))
+            placement = kwargs.get("placement", flow.placement(device, [0]))
             return_token_ids = flow.tensor(tokens, sbp=sbp, placement=placement, dtype=flow.long)
         else:
             return_token_ids = flow.tensor(tokens, dtype=flow.long)
diff --git a/projects/Llama/utils/llama_loader.py b/projects/Llama/utils/llama_loader.py
index 20b9ba258..c46cb480a 100644
--- a/projects/Llama/utils/llama_loader.py
+++ b/projects/Llama/utils/llama_loader.py
@@ -26,6 +26,8 @@ def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
         self.base_model_prefix_1 = "model"
         self.base_model_prefix_2 = "model"
+        if not pretrained_model_path:
+            self.pretrained_model_path = libai_cfg.pretrained_model_path
 
     def _convert_state_dict(self, flow_state_dict, cfg):
         """Convert state_dict's keys to match model.
@@ -104,3 +106,5 @@ class LlamaLoaderLiBai(ModelLoaderLiBai):
     def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
         super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
         self.base_model_prefix_2 = "model"
+        if not pretrained_model_path:
+            self.pretrained_model_path = libai_cfg.pretrained_model_path
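
For reference, the README commands above correspond to the following programmatic use of the updated pipeline. This is a minimal sketch, not part of the patch: the import path assumes `projects/` is importable from the repository root, and the model path and prompt are placeholders; the constructor arguments mirror `main()` in `projects/Llama/pipeline.py` after this change.

```python
# Sketch only: equivalent of `python projects/Llama/pipeline.py --device=npu
# --mode=huggingface --model_path /your/model/path`, assuming this patch is applied.
from projects.Llama.pipeline import TextGenerationPipeline  # assumed import path
from libai.utils import distributed as dist

pipeline = TextGenerationPipeline(
    "projects/Llama/configs/llama_config.py",
    data_parallel=1,
    tensor_parallel=1,
    pipeline_parallel=1,
    pipeline_num_layers=32,
    model_path="/your/model/path",  # placeholder: HuggingFace-format Llama checkpoint dir
    mode="huggingface",             # or "libai" for LiBai-format checkpoints
    device="npu",                   # "cuda", "xpu", or "npu"; npu/xpu require the
                                    # oneflow_npu / oneflow_xpu extensions (see distributed.py)
)

output = pipeline(inputs=["what is beam search?"])  # placeholder prompt
if dist.is_main_process():
    print(output)
```

Because `BasePipeline` now copies `device` into `cfg.train.dist.device_type` before `setup_dist_util`, the same config file can be reused unchanged across cuda, npu, and xpu runs.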