Skip to content

Commit

Permalink
update hpi config
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyubo0722 committed Nov 8, 2024
1 parent aae2a3e commit c6d61fa
Showing 1 changed file with 37 additions and 39 deletions.
76 changes: 37 additions & 39 deletions ppocr/utils/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,42 +39,42 @@ def setup_orderdict():
def dump_infer_config(config, path, logger):
setup_orderdict()
infer_cfg = OrderedDict()
if config["Global"].get("hpi_config_path", None):
hpi_config = yaml.safe_load(open(config["Global"]["hpi_config_path"], "r"))
rec_resize_img_dict = next(
(
item
for item in config["Eval"]["dataset"]["transforms"]
if "RecResizeImg" in item
),
None,
)
if rec_resize_img_dict:
dynamic_shapes = [1] + rec_resize_img_dict["RecResizeImg"]["image_shape"]
if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
"dynamic_shapes"
]["x"] = [dynamic_shapes for i in range(3)]
hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"][
"max_batch_size"
] = 1
if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
hpi_config["Hpi"]["backend_config"]["tensorrt"]["dynamic_shapes"][
"x"
] = [dynamic_shapes for i in range(3)]
hpi_config["Hpi"]["backend_config"]["tensorrt"]["max_batch_size"] = 1
else:
if hpi_config["Hpi"]["backend_config"].get("paddle_tensorrt", None):
hpi_config["Hpi"]["supported_backends"]["gpu"].remove("paddle_tensorrt")
del hpi_config["Hpi"]["backend_config"]["paddle_tensorrt"]
if hpi_config["Hpi"]["backend_config"].get("tensorrt", None):
hpi_config["Hpi"]["supported_backends"]["gpu"].remove("tensorrt")
del hpi_config["Hpi"]["backend_config"]["tensorrt"]
hpi_config["Hpi"]["selected_backends"]["gpu"] = "paddle_infer"
infer_cfg["Hpi"] = hpi_config["Hpi"]
if config["Global"].get("pdx_model_name", None):
infer_cfg["Global"] = {}
infer_cfg["Global"]["model_name"] = config["Global"]["pdx_model_name"]
infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]}
if config["Global"].get("uniform_output_enabled", None):
arch_config = config["Architecture"]
if arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
common_dynamic_shapes = {
"x": [[1, 3, 48, 320], [1, 3, 48, 320], [8, 3, 48, 320]]
}
elif arch_config["model_type"] == "det":
common_dynamic_shapes = {
"x": [[1, 3, 160, 160], [1, 3, 160, 160], [1, 3, 1280, 1280]]
}
elif arch_config["algorithm"] == "SLANet":
common_dynamic_shapes = {
"x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 192, 672]]
}
elif arch_config["algorithm"] == "LaTeXOCR":
common_dynamic_shapes = {
"x": [[1, 3, 224, 224], [1, 3, 448, 448], [8, 3, 1280, 1280]]
}
else:
common_dynamic_shapes = None

backend_keys = ["paddle_infer", "tensorrt"]
hpi_config = {
"backend_configs": {
key: {
(
"dynamic_shapes" if key == "tensorrt" else "trt_dynamic_shapes"
): common_dynamic_shapes
}
for key in backend_keys
}
}
if common_dynamic_shapes:
infer_cfg["Hpi"] = hpi_config

infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]}
postprocess = OrderedDict()
Expand All @@ -96,10 +96,8 @@ def dump_infer_config(config, path, logger):

infer_cfg["PostProcess"] = postprocess

with open(path, "w") as f:
yaml.dump(
infer_cfg, f, default_flow_style=False, encoding="utf-8", allow_unicode=True
)
with open(path, "w", encoding="utf-8") as f:
yaml.dump(infer_cfg, f, default_flow_style=False, allow_unicode=True)
logger.info("Export inference config file to {}".format(os.path.join(path)))


Expand Down

0 comments on commit c6d61fa

Please sign in to comment.