diff --git a/ppocr/utils/export_model.py b/ppocr/utils/export_model.py
index a62e8109a2..c58e5be5f9 100644
--- a/ppocr/utils/export_model.py
+++ b/ppocr/utils/export_model.py
@@ -39,42 +39,42 @@ def setup_orderdict():
 def dump_infer_config(config, path, logger):
     setup_orderdict()
     infer_cfg = OrderedDict()
-    if config["Global"].get("hpi_config_path", None):
-        hpi_config = yaml.safe_load(open(config["Global"]["hpi_config_path"], "r"))
-        rec_resize_img_dict = next(
-            (
-                item
-                for item in config["Eval"]["dataset"]["transforms"]
-                if "RecResizeImg" in item
-            ),
-            None,
-        )
-        if rec_resize_img_dict:
-            dynamic_shapes = [1] + rec_resize_img_dict["RecResizeImg"]["image_shape"]
-            if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
-                hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
-                    "dynamic_shapes"
-                ]["x"] = [dynamic_shapes for i in range(3)]
-                hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
-                    "max_batch_size"
-                ] = 1
-            if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
-                hpi_config["Hpi"]["backend_config"]["tensorrt"]["dynamic_shapes"][
-                    "x"
-                ] = [dynamic_shapes for i in range(3)]
-                hpi_config["Hpi"]["backend_config"]["tensorrt"]["max_batch_size"] = 1
-        else:
-            if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
-                hpi_config["Hpi"]["supported_backends"]["gpu"].remove("paddle_tensorrt")
-                del hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"]
-            if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
-                hpi_config["Hpi"]["supported_backends"]["gpu"].remove("tensorrt")
-                del hpi_config["Hpi"]["backend_config"]["tensorrt"]
-            hpi_config["Hpi"]["selected_backends"]["gpu"] = "paddle_infer"
-        infer_cfg["Hpi"] = hpi_config["Hpi"]
     if config["Global"].get("pdx_model_name", None):
-        infer_cfg["Global"] = {}
-        infer_cfg["Global"]["model_name"] = config["Global"]["pdx_model_name"]
+        infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]}
+    if config["Global"].get("uniform_output_enabled", None):
+        arch_config = config["Architecture"]
+        if arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
+            common_dynamic_shapes = {
+                "x": [[1, 3, 48, 320], [1, 3, 48, 320], [8, 3, 48, 320]]
+            }
+        elif arch_config["model_type"] == "det":
+            common_dynamic_shapes = {
+                "x": [[1, 3, 160, 160], [1, 3, 160, 160], [1, 3, 1280, 1280]]
+            }
+        elif arch_config["algorithm"] == "SLANet":
+            common_dynamic_shapes = {
+                "x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 192, 672]]
+            }
+        elif arch_config["algorithm"] == "LaTeXOCR":
+            common_dynamic_shapes = {
+                "x": [[1, 3, 224, 224], [1, 3, 448, 448], [8, 3, 1280, 1280]]
+            }
+        else:
+            common_dynamic_shapes = None
+
+        backend_keys = ["paddle_infer", "tensorrt"]
+        hpi_config = {
+            "backend_configs": {
+                key: {
+                    (
+                        "dynamic_shapes" if key == "tensorrt" else "trt_dynamic_shapes"
+                    ): common_dynamic_shapes
+                }
+                for key in backend_keys
+            }
+        }
+        if common_dynamic_shapes:
+            infer_cfg["Hpi"] = hpi_config
     infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]}

     postprocess = OrderedDict()
@@ -96,10 +96,8 @@ def dump_infer_config(config, path, logger):
     infer_cfg["PostProcess"] = postprocess
-    with open(path, "w") as f:
-        yaml.dump(
-            infer_cfg, f, default_flow_style=False, encoding="utf-8", allow_unicode=True
-        )
+    with open(path, "w", encoding="utf-8") as f:
+        yaml.dump(infer_cfg, f, default_flow_style=False, allow_unicode=True)
     logger.info("Export inference config file to {}".format(os.path.join(path)))
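
For context, the snippet below is a minimal standalone sketch (not part of the patch) that reproduces the dict comprehension introduced above for a recognition model such as SVTR_LCNet: one min/opt/max trio of input shapes is mapped to the key each backend expects, "dynamic_shapes" for tensorrt and "trt_dynamic_shapes" for paddle_infer, and dumped the same way dump_infer_config writes the config. The exact YAML layout is illustrative only.

# Standalone sketch, not part of the patch: builds the Hpi block for SVTR_LCNet.
import yaml

# min / opt / max input shapes taken from the SVTR_LCNet branch of the diff
common_dynamic_shapes = {"x": [[1, 3, 48, 320], [1, 3, 48, 320], [8, 3, 48, 320]]}

backend_keys = ["paddle_infer", "tensorrt"]
hpi_config = {
    "backend_configs": {
        # tensorrt uses "dynamic_shapes"; paddle_infer uses "trt_dynamic_shapes"
        key: {
            (
                "dynamic_shapes" if key == "tensorrt" else "trt_dynamic_shapes"
            ): common_dynamic_shapes
        }
        for key in backend_keys
    }
}

# Same dump options as the patched dump_infer_config
print(yaml.dump({"Hpi": hpi_config}, default_flow_style=False, allow_unicode=True))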