feat: support third-party oneflow device extension (#549)
* feat: support third-party oneflow device extensions

also, refactor the build process of the model and tokenizer to use the
pretrained_model_path config (see the sketch below)

* refactor: remove unnecessary config and warnings

* docs: update readme for commands to run llama on npu and xpu
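
For illustration, a minimal sketch of the config-driven flow this commit moves to. The exact config layout below is an assumption pieced together from the diffs, not code from this commit:

```python
# Hypothetical LazyConfig-style fragment (sketch only). After this commit the
# weights directory is read from cfg.model.cfg.pretrained_model_path, so the
# per-device config files (llama_config_npu.py / llama_config_xpu.py) are no
# longer needed; the device is selected with --device instead.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "model": {"cfg": {"pretrained_model_path": "/your/model/path"}},
        # The tokenizer's pretrained_model_path may be omitted; BasePipeline
        # now defaults it to f"{model_path}/tokenizer.model".
        "tokenization": {"tokenizer": {}},
    }
)

print(cfg.model.cfg.pretrained_model_path)  # /your/model/path
```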
0x404 authored Sep 4, 2024
1 parent d4bd6db commit 593937f
Showing 7 changed files with 38 additions and 140 deletions.
20 changes: 18 additions & 2 deletions libai/inference/basic.py
@@ -16,6 +16,7 @@
import logging
from abc import ABCMeta, abstractmethod
from typing import Any, Dict
from pathlib import Path

import oneflow as flow

@@ -62,12 +63,20 @@ def __init__(
pipeline_num_layers,
)
self.device = device
if device:
self.cfg.train.dist.device_type = device
self.cfg.train.dist.device_type = device
dist.setup_dist_util(self.cfg.train.dist)
logger.info(self.cfg.train.dist)

# initialize and load the model
self.model_path = model_path
if self.model_path is not None:
# If a model_path is provided in BasePipeline,
# use it with priority and overwrite the pretrained_model_path in the config
self.cfg.model.cfg.pretrained_model_path = self.model_path
else:
# If the model_path in BasePipeline is None, then use the one from the config
assert "pretrained_model_path" in self.cfg.model.cfg
self.model_path = self.cfg.model.cfg.pretrained_model_path

self.model = self.load_pretrain_weight(self.cfg.model, model_path, mode=mode)
self.model._apply(dist.convert_to_distributed_default_setting)
@@ -138,6 +147,13 @@ def load_pretrain_weight(
def build_tokenizer(self, cfg):
tokenizer = None
if try_get_key(cfg, "tokenization") is not None:
tokenizer_cfg = cfg.tokenization.tokenizer
if "pretrained_model_path" not in tokenizer_cfg:
# If "pretrained_model_path" does not exist in the tokenizer's config,
# default it to f"{model_path}/tokenizer.model"
tokenizer_cfg.pretrained_model_path = str(
Path(self.model_path).joinpath("tokenizer.model")
)
tokenizer = DefaultTrainer.build_tokenizer(cfg)
return tokenizer

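A standalone sketch of the resolution order introduced in BasePipeline above; the function and config layout here are illustrative, not the actual libai API:

```python
from pathlib import Path

def resolve_paths(cli_model_path, cfg):
    """Sketch of the new priority: an explicit model_path overrides the config,
    otherwise the config's pretrained_model_path is used; the tokenizer path
    defaults to <model_path>/tokenizer.model when not set explicitly."""
    if cli_model_path is not None:
        cfg["model"]["cfg"]["pretrained_model_path"] = cli_model_path
        model_path = cli_model_path
    else:
        model_path = cfg["model"]["cfg"]["pretrained_model_path"]

    tokenizer_cfg = cfg["tokenization"]["tokenizer"]
    if "pretrained_model_path" not in tokenizer_cfg:
        tokenizer_cfg["pretrained_model_path"] = str(Path(model_path) / "tokenizer.model")
    return model_path, tokenizer_cfg["pretrained_model_path"]

# Example: no --model_path given, so both paths come from the config.
cfg = {
    "model": {"cfg": {"pretrained_model_path": "/your/model/path"}},
    "tokenization": {"tokenizer": {}},
}
print(resolve_paths(None, cfg))  # ('/your/model/path', '/your/model/path/tokenizer.model')
```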
16 changes: 16 additions & 0 deletions libai/utils/distributed.py
@@ -72,6 +72,22 @@ def _init_distributed_env(self, cfg):

# Add set device type
self._device_type = try_get_key(cfg, "device_type", default="cuda")
if self._device_type == "npu":
try:
import oneflow_npu
except ImportError:
raise ImportError(
"The module 'oneflow_npu' is not installed. Please install it to use NPU devices."
)
elif self._device_type == "xpu":
try:
import oneflow_xpu
except ImportError:
raise ImportError(
"The module 'oneflow_xpu' is not installed. Please install it to use XPU devices."
)
elif self._device_type not in ("cuda", "npu", "xpu", "cpu"):
raise NotImplementedError(f"Unsupported device {self._device_type}")

def _init_parallel_size(self, cfg):

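A hedged, self-contained sketch of the device-extension check added above (simplified; `check_device_extension` is an illustrative name, not a libai function):

```python
import importlib.util

# Third-party OneFlow device backends ship as separate packages; selecting
# "npu" or "xpu" only works if the matching extension module is importable.
_DEVICE_EXTENSIONS = {"npu": "oneflow_npu", "xpu": "oneflow_xpu"}

def check_device_extension(device_type: str) -> None:
    if device_type in _DEVICE_EXTENSIONS:
        module = _DEVICE_EXTENSIONS[device_type]
        if importlib.util.find_spec(module) is None:
            raise ImportError(
                f"The module '{module}' is not installed. "
                f"Please install it to use {device_type.upper()} devices."
            )
    elif device_type not in ("cuda", "cpu"):
        raise NotImplementedError(f"Unsupported device {device_type}")

check_device_extension("cuda")  # built-in device types pass through unchanged
```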
4 changes: 2 additions & 2 deletions projects/Llama/README.md
@@ -50,11 +50,11 @@ bash tools/infer.sh projects/Llama/pipeline.py 8

- npu
```bash
python projects/Llama/pipeline.py --device=npu --mode=huggingface --config_file=projects/Llama/configs/llama_config_npu.py
python projects/Llama/pipeline.py --device=npu --mode=huggingface --model_path /your/model/path
```

- xpu
```bash
python projects/Llama/pipeline.py --device=xpu --mode=huggingface --config_file=projects/Llama/configs/llama_config_xpu.py
python projects/Llama/pipeline.py --device=xpu --mode=huggingface --model_path /your/model/path
```

2 changes: 1 addition & 1 deletion projects/Llama/configs/llama_config.py
@@ -57,5 +57,5 @@
tokenization = OmegaConf.create()
tokenization.make_vocab_size_divisible_by = 1
tokenization.tokenizer = LazyCall(LlamaTokenizer)(
pretrained_model_path="meta-llama/Llama-2-7b-hf/tokenizer.model"
# pretrained_model_path="meta-llama/Llama-2-7b-hf/tokenizer.model"
)
64 changes: 0 additions & 64 deletions projects/Llama/configs/llama_config_npu.py

This file was deleted.

64 changes: 0 additions & 64 deletions projects/Llama/configs/llama_config_xpu.py

This file was deleted.

8 changes: 1 addition & 7 deletions projects/Llama/pipeline.py
@@ -95,7 +95,7 @@ def postprocess(self, model_output_dict, **kwargs) -> dict:
default="projects/Llama/configs/llama_config.py",
help="Path to the configuration file.",
)
@click.option("--model_path", default="", help="Path to the model checkpoint.")
@click.option("--model_path", default=None, help="Path to the model checkpoint.")
@click.option(
"--mode",
default="libai",
@@ -105,12 +105,6 @@ def postprocess(self, model_output_dict, **kwargs) -> dict:
"--device", default="cuda", help="Device to run the model on, e.g., 'cuda', 'xpu', 'npu'."
)
def main(config_file, model_path, mode, device):
if model_path:
print(
"Note: The '--model_path' option is for the model checkpoint only. "
"Please configure 'tokenization.tokenizer.pretrained_model_path' "
"directly in the config file."
)
pipeline = TextGenerationPipeline(
config_file,
data_parallel=1,
