diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index afd7c6ad97..2be2c0de39 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -273,7 +273,7 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num= assert last_num >= 1 train_results_path = os.path.join( - config["Global"]["save_model_dir"], "train_results.json" + config["Global"]["save_model_dir"], "train_result.json" ) save_model_tag = ["pdparams", "pdopt", "pdstates"] save_inference_tag = ["inference_config", "pdmodel", "pdiparams", "pdiparams.info"] diff --git a/tools/train.py b/tools/train.py index 7480b2bd77..0a2e2e6458 100755 --- a/tools/train.py +++ b/tools/train.py @@ -172,11 +172,11 @@ def main(config, device, logger, vdl_writer, seed): amp_custom_black_list = config["Global"].get("amp_custom_black_list", []) amp_custom_white_list = config["Global"].get("amp_custom_white_list", []) if os.path.exists( - os.path.join(config["Global"]["save_model_dir"], "train_results.json") + os.path.join(config["Global"]["save_model_dir"], "train_result.json") ): try: os.remove( - os.path.join(config["Global"]["save_model_dir"], "train_results.json") + os.path.join(config["Global"]["save_model_dir"], "train_result.json") ) except: pass