From 31d099b435a1cbcbe5b9d743c1ce7e8df0f97529 Mon Sep 17 00:00:00 2001 From: zhangyubo0722 Date: Wed, 16 Oct 2024 13:17:15 +0000 Subject: [PATCH] add other devices supported model list --- .../formula_recognition/LaTeX_OCR_rec.yaml | 2 +- paddlex/inference/utils/new_ir_blacklist.py | 2 +- paddlex/inference/utils/pp_option.py | 18 +- paddlex/modules/base/evaluator.py | 11 +- paddlex/modules/base/exportor.py | 11 +- paddlex/modules/base/trainer.py | 11 +- paddlex/paddlex_cli.py | 15 ++ paddlex/utils/device.py | 24 +- paddlex/utils/file_interface.py | 1 - paddlex/utils/other_devices_model_list.py | 224 ++++++++++++++++++ 10 files changed, 294 insertions(+), 25 deletions(-) create mode 100644 paddlex/utils/other_devices_model_list.py diff --git a/paddlex/configs/formula_recognition/LaTeX_OCR_rec.yaml b/paddlex/configs/formula_recognition/LaTeX_OCR_rec.yaml index dae4c941c..a7e499e06 100644 --- a/paddlex/configs/formula_recognition/LaTeX_OCR_rec.yaml +++ b/paddlex/configs/formula_recognition/LaTeX_OCR_rec.yaml @@ -2,7 +2,7 @@ Global: model: LaTeX_OCR_rec mode: check_dataset # check_dataset/train/evaluate/predict dataset_dir: "./dataset/ocr_rec_latexocr_dataset_example" - device: gpu:0 + device: gpu:0,1,2,3 output: "output" CheckDataset: diff --git a/paddlex/inference/utils/new_ir_blacklist.py b/paddlex/inference/utils/new_ir_blacklist.py index cfd6bbb2a..c664fbe84 100644 --- a/paddlex/inference/utils/new_ir_blacklist.py +++ b/paddlex/inference/utils/new_ir_blacklist.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -NEWIR_BLOCKLIST = [ +NEWIR_BLACKLIST = [ "FasterRCNN-ResNet34-FPN", "FasterRCNN-ResNet50", "FasterRCNN-ResNet50-FPN", diff --git a/paddlex/inference/utils/pp_option.py b/paddlex/inference/utils/pp_option.py index 1af258fa5..ca7995471 100644 --- a/paddlex/inference/utils/pp_option.py +++ b/paddlex/inference/utils/pp_option.py @@ -12,9 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ...utils.device import parse_device, set_env_for_device, get_default_device +from ...utils.device import ( + parse_device, + set_env_for_device, + get_default_device, + check_device, +) from ...utils import logging -from .new_ir_blacklist import NEWIR_BLOCKLIST +from .new_ir_blacklist import NEWIR_BLACKLIST class PaddlePredictorOption(object): @@ -28,7 +33,6 @@ class PaddlePredictorOption(object): "mkldnn", "mkldnn_bf16", ) - SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu", "dcu") def __init__(self, model_name=None, **kwargs): super().__init__() @@ -61,7 +65,7 @@ def _get_default_config(self): "cpu_threads": 1, "trt_use_static": False, "delete_pass": [], - "enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False, + "enable_new_ir": True if self.model_name not in NEWIR_BLACKLIST else False, "batch_size": 1, # only for trt } @@ -101,11 +105,7 @@ def device(self, device: str): if not device: return device_type, device_ids = parse_device(device) - if device_type not in self.SUPPORT_DEVICE: - support_run_mode_str = ", ".join(self.SUPPORT_DEVICE) - raise ValueError( - f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}." - ) + check_device(self.model_name, device_type) self._update("device", device_type) device_id = device_ids[0] if device_ids is not None else 0 self._update("device_id", device_id) diff --git a/paddlex/modules/base/evaluator.py b/paddlex/modules/base/evaluator.py index 72e8da019..e5fe74962 100644 --- a/paddlex/modules/base/evaluator.py +++ b/paddlex/modules/base/evaluator.py @@ -17,7 +17,12 @@ from abc import ABC, abstractmethod from .build_model import build_model -from ...utils.device import update_device_num, set_env_for_device +from ...utils.device import ( + update_device_num, + set_env_for_device, + parse_device, + check_device, +) from ...utils.misc import AutoRegisterABCMetaClass from ...utils.config import AttrDict from ...utils.logging import * @@ -138,8 +143,10 @@ def get_device(self, using_device_number: int = None) -> str: Returns: str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`. """ + device_type, device_ids = parse_device(self.global_config.device) + check_device(self.global_config.model, device_type) if using_device_number: - return update_device_num(self.global_config.device, using_device_number) + return update_device_num(device_type, device_ids, using_device_number) set_env_for_device(self.global_config.device) return self.global_config.device diff --git a/paddlex/modules/base/exportor.py b/paddlex/modules/base/exportor.py index b8fde23d0..4653647cb 100644 --- a/paddlex/modules/base/exportor.py +++ b/paddlex/modules/base/exportor.py @@ -17,7 +17,12 @@ from abc import ABC, abstractmethod from .build_model import build_model -from ...utils.device import update_device_num, set_env_for_device +from ...utils.device import ( + update_device_num, + set_env_for_device, + parse_device, + check_device, +) from ...utils.misc import AutoRegisterABCMetaClass from ...utils.config import AttrDict from ...utils import logging @@ -103,8 +108,10 @@ def get_device(self, using_device_number: int = None) -> str: Returns: str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`. """ + device_type, device_ids = parse_device(self.global_config.device) + check_device(self.global_config.model, device_type) if using_device_number: - return update_device_num(self.global_config.device, using_device_number) + return update_device_num(device_type, device_ids, using_device_number) set_env_for_device(self.global_config.device) return self.global_config.device diff --git a/paddlex/modules/base/trainer.py b/paddlex/modules/base/trainer.py index 48e0eb394..a4b7d9eaf 100644 --- a/paddlex/modules/base/trainer.py +++ b/paddlex/modules/base/trainer.py @@ -16,7 +16,12 @@ from abc import ABC, abstractmethod from pathlib import Path from .build_model import build_model -from ...utils.device import update_device_num, set_env_for_device +from ...utils.device import ( + update_device_num, + set_env_for_device, + parse_device, + check_device, +) from ...utils.misc import AutoRegisterABCMetaClass from ...utils.config import AttrDict @@ -95,8 +100,10 @@ def get_device(self, using_device_number: int = None) -> str: Returns: str: device setting, such as: `gpu:0,1`, `npu:0,1` `cpu`. """ + device_type, device_ids = parse_device(self.global_config.device) + check_device(self.global_config.model, device_type) if using_device_number: - return update_device_num(self.global_config.device, using_device_number) + return update_device_num(device_type, device_ids, using_device_number) set_env_for_device(self.global_config.device) return self.global_config.device diff --git a/paddlex/paddlex_cli.py b/paddlex/paddlex_cli.py index 1d1e3f496..4685e84dd 100644 --- a/paddlex/paddlex_cli.py +++ b/paddlex/paddlex_cli.py @@ -16,11 +16,14 @@ import argparse import subprocess import sys +import shutil import tempfile +from pathlib import Path from . import create_pipeline from .inference.pipelines import create_pipeline_from_config, load_pipeline_config from .repo_manager import setup, get_all_supported_repo_names +from .utils.cache import CACHE_DIR from .utils import logging from .utils.interactive_get_pipeline import interactive_get_pipeline @@ -65,6 +68,7 @@ def parse_str(s): ################# install pdx ################# parser.add_argument("--install", action="store_true", default=False, help="") + parser.add_argument("--clear_cache", action="store_true", default=False, help="") parser.add_argument("plugins", nargs="*", default=[]) parser.add_argument("--no_deps", action="store_true") parser.add_argument("--platform", type=str, default="github.com") @@ -159,6 +163,15 @@ def serve(pipeline, *, device, use_hpip, serial_number, update_license, host, po run_server(app, host=host, port=port, debug=False) +def clear_cache(): + cache_dir = Path(CACHE_DIR) / "official_models" + if cache_dir.exists() and cache_dir.is_dir(): + shutil.rmtree(cache_dir) + logging.info(f"Successfully cleared the cache models at {cache_dir}") + else: + logging.info(f"No cache models found at {cache_dir}") + + # for CLI def main(): """API for commad line""" @@ -180,6 +193,8 @@ def main(): host=args.host, port=args.port, ) + elif args.clear_cache: + clear_cache() else: if args.get_pipeline_config is not None: interactive_get_pipeline(args.get_pipeline_config, args.save_path) diff --git a/paddlex/utils/device.py b/paddlex/utils/device.py index 16ce8ff15..472731fea 100644 --- a/paddlex/utils/device.py +++ b/paddlex/utils/device.py @@ -18,8 +18,7 @@ from . import logging from .errors import raise_unsupported_device_error - -SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu"] +from .other_devices_model_list import OTHER_DEVICES_MODEL_LIST def _constr_device(device_type, device_ids): @@ -38,6 +37,21 @@ def get_default_device(): return _constr_device("gpu", [avail_gpus[0]]) +def check_device(model_name, device_type): + supported_device_type = ["cpu", "gpu", "xpu", "npu", "mlu", "dcu"] + device_type = device_type.lower() + if device_type not in supported_device_type: + support_run_mode_str = ", ".join(supported_device_type) + raise ValueError( + f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}." + ) + if device_type in OTHER_DEVICES_MODEL_LIST: + if model_name not in OTHER_DEVICES_MODEL_LIST[device_type]: + raise ValueError( + f"The model '{model_name}' is not supported on {device_type}." + ) + + def parse_device(device): """parse_device""" # According to https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/set_device_cn.html @@ -55,14 +69,10 @@ def parse_device(device): f"Device ID must be an integer. Invalid device ID: {device_id}" ) device_ids = list(map(int, device_ids)) - device_type = device_type.lower() - # raise_unsupported_device_error(device_type, SUPPORTED_DEVICE_TYPE) - assert device_type.lower() in SUPPORTED_DEVICE_TYPE return device_type, device_ids -def update_device_num(device, num): - device_type, device_ids = parse_device(device) +def update_device_num(device_type, device_ids, num): if device_ids: assert len(device_ids) >= num return _constr_device(device_type, device_ids[:num]) diff --git a/paddlex/utils/file_interface.py b/paddlex/utils/file_interface.py index f86251359..82954b631 100644 --- a/paddlex/utils/file_interface.py +++ b/paddlex/utils/file_interface.py @@ -25,7 +25,6 @@ try: import ujson as json except: - logging.error("failed to import ujson, using json instead") import json from contextlib import contextmanager diff --git a/paddlex/utils/other_devices_model_list.py b/paddlex/utils/other_devices_model_list.py new file mode 100644 index 000000000..c2f024cd9 --- /dev/null +++ b/paddlex/utils/other_devices_model_list.py @@ -0,0 +1,224 @@ +# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +OTHER_DEVICES_MODEL_LIST = { + "xpu": [ + "MobileNetV3_large_x0_5", + "MobileNetV3_large_x0_35", + "MobileNetV3_large_x0_75", + "MobileNetV3_large_x1_0", + "MobileNetV3_large_x1_25", + "MobileNetV3_small_x0_5", + "MobileNetV3_small_x0_35", + "MobileNetV3_small_x0_75", + "MobileNetV3_small_x1_0", + "MobileNetV3_small_x1_25", + "PP-HGNet_small", + "PP-LCNet_x0_5", + "PP-LCNet_x0_25", + "PP-LCNet_x0_35", + "PP-LCNet_x0_75", + "PP-LCNet_x1_0", + "PP-LCNet_x1_5", + "PP-LCNet_x2_0", + "PP-LCNet_x2_5", + "ResNet18", + "ResNet34", + "ResNet50", + "ResNet101", + "ResNet152", + "PicoDet-L", + "PicoDet-S", + "PP-YOLOE_plus-L", + "PP-YOLOE_plus-M", + "PP-YOLOE_plus-S", + "PP-YOLOE_plus-X", + "PP-LiteSeg-T", + "PP-OCRv4_mobile_det", + "PP-OCRv4_server_det", + "PP-OCRv4_mobile_rec", + "PP-OCRv4_server_rec", + "PicoDet_layout_1x", + "DLinear", + "NLinear", + "RLinear", + ], + "npu": [ + "CLIP_vit_base_patch16_224", + "CLIP_vit_large_patch14_224", + "ConvNeXt_base_224", + "ConvNeXt_base_384", + "ConvNeXt_large_224", + "ConvNeXt_large_384", + "ConvNeXt_small", + "ConvNeXt_tiny", + "MobileNetV1_x0_75", + "MobileNetV1_x1_0", + "MobileNetV2_x0_5", + "MobileNetV2_x0_25", + "MobileNetV2_x1_0", + "MobileNetV2_x1_5", + "MobileNetV2_x2_0", + "MobileNetV3_large_x0_5", + "MobileNetV3_large_x0_35", + "MobileNetV3_large_x0_75", + "MobileNetV3_large_x1_0", + "MobileNetV3_large_x1_25", + "MobileNetV3_small_x0_5", + "MobileNetV3_small_x0_35", + "MobileNetV3_small_x0_75", + "MobileNetV3_small_x1_0", + "MobileNetV3_small_x1_25", + "PP-HGNet_base", + "PP-HGNet_small", + "PP-HGNet_tiny", + "PP-HGNetV2-B0", + "PP-HGNetV2-B1", + "PP-HGNetV2-B2", + "PP-HGNetV2-B3", + "PP-HGNetV2-B4", + "PP-HGNetV2-B5", + "PP-HGNetV2-B6", + "PP-LCNet_x0_5", + "PP-LCNet_x0_25", + "PP-LCNet_x0_35", + "PP-LCNet_x0_75", + "PP-LCNet_x1_0", + "PP-LCNet_x1_5", + "PP-LCNet_x2_0", + "PP-LCNet_x2_5", + "PP-LCNetV2_base", + "ResNet18_vd", + "ResNet18", + "ResNet34_vd", + "ResNet34", + "ResNet50_vd", + "ResNet50", + "ResNet101_vd", + "ResNet101", + "ResNet152_vd", + "ResNet152", + "ResNet200_vd", + "SwinTransformer_base_patch4_window7_224", + "SwinTransformer_small_patch4_window7_224", + "SwinTransformer_tiny_patch4_window7_224", + "CenterNet-DLA-34", + "CenterNet-ResNet50", + "DETR-R50", + "FasterRCNN-ResNet34-FPN", + "FasterRCNN-ResNet50-FPN", + "FasterRCNN-ResNet50-vd-SSLDv2-FPN", + "FasterRCNN-ResNet101-FPN", + "FCOS-ResNet50", + "PicoDet-L", + "PicoDet-M", + "PicoDet-S", + "PicoDet-XS", + "PP-YOLOE_plus-L", + "PP-YOLOE_plus-M", + "PP-YOLOE_plus-S", + "PP-YOLOE_plus-X", + "RT-DETR-H", + "RT-DETR-L", + "RT-DETR-R18", + "RT-DETR-R50", + "RT-DETR-X", + "YOLOv3-DarkNet53", + "YOLOv3-MobileNetV3", + "YOLOv3-ResNet50_vd_DCN", + "Deeplabv3_Plus-R50", + "Deeplabv3_Plus-R101", + "Deeplabv3-R50", + "Deeplabv3-R101", + "OCRNet_HRNet-W48", + "PP-LiteSeg-T", + "Mask-RT-DETR-H", + "Mask-RT-DETR-L", + "Mask-RT-DETR-M", + "Cascade-MaskRCNN-ResNet50-FPN", + "Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN", + "PP-YOLOE_seg-S", + "PP-OCRv4_mobile_det", + "PP-OCRv4_server_det", + "PP-OCRv4_mobile_rec", + "PP-OCRv4_server_rec", + "ch_SVTRv2_rec", + "ch_RepSVTR_rec", + "SLANet", + "PicoDet_layout_1x", + "DLinear", + "NLinear", + "Nonstationary", + "PatchTST", + "RLinear", + "TiDE", + "TimesNet", + "AutoEncoder_ad", + "DLinear_ad", + "Nonstationary_ad", + "PatchTST_ad", + "TimesNet_ad", + "TimesNet_cls", + ], + "mlu": [ + "MobileNetV3_large_x0_5", + "MobileNetV3_large_x0_35", + "MobileNetV3_large_x0_75", + "MobileNetV3_large_x1_0", + "MobileNetV3_large_x1_25", + "MobileNetV3_small_x0_5", + "MobileNetV3_small_x0_35", + "MobileNetV3_small_x0_75", + "MobileNetV3_small_x1_0", + "MobileNetV3_small_x1_25", + "PP-HGNet_small", + "PP-LCNet_x0_5", + "PP-LCNet_x0_25", + "PP-LCNet_x0_35", + "PP-LCNet_x0_75", + "PP-LCNet_x1_0", + "PP-LCNet_x1_5", + "PP-LCNet_x2_0", + "PP-LCNet_x2_5", + "ResNet18", + "ResNet34", + "ResNet50", + "ResNet101", + "ResNet152", + "PicoDet-L", + "PicoDet-S", + "PP-YOLOE_plus-L", + "PP-YOLOE_plus-M", + "PP-YOLOE_plus-S", + "PP-YOLOE_plus-X", + "PP-LiteSeg-T", + "PP-OCRv4_mobile_det", + "PP-OCRv4_server_det", + "PP-OCRv4_mobile_rec", + "PP-OCRv4_server_rec", + "PicoDet_layout_1x", + "DLinear", + "NLinear", + "RLinear", + ], + "dcu": [ + "ResNet18", + "ResNet34", + "ResNet50", + "ResNet101", + "ResNet152", + "Deeplabv3_Plus-R50", + "Deeplabv3_Plus-R101", + ], +}