From f33f43b2ab4262ba50bc2f04ca27b9e5e3825800 Mon Sep 17 00:00:00 2001 From: Armand Collin <83031821+hermancollin@users.noreply.github.com> Date: Wed, 13 Mar 2024 12:39:54 -0400 Subject: [PATCH] Add support for best checkpoints + update `download_models` (#12) * Add support for best checkpoints * Update release tag of TEM unmyelinated axon seg model * Add myelin cutter model for download --- download_models.py | 9 +++++++-- nn_axondeepseg.py | 9 ++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/download_models.py b/download_models.py index 9390b98..38d044b 100644 --- a/download_models.py +++ b/download_models.py @@ -14,8 +14,8 @@ from requests.packages.urllib3.util import Retry MODELS = { - 'model_seg_unmyelinated_tem' : { - 'url': 'https://github.com/axondeepseg/model_seg_unmyelinated_tem/releases/download/v1.0.0/model_seg_unmyelinated_tem.zip', + 'model_seg_unmyelinated_sickkids_tem_best' : { + 'url': 'https://github.com/axondeepseg/model_seg_unmyelinated_tem/releases/download/v1.1.0/model_seg_unmyelinated_sickkids_tem_best.zip', 'description': 'Unmyelinated axon segmentation (1-class)', 'contrasts': ['TEM'], }, @@ -24,6 +24,11 @@ 'description': 'Axon and myelin segmentation on Toluidine Blue stained BF images (rabbit)', 'contrasts': ['BF'], }, + 'model_myelin_cutter': { + 'url': 'https://github.com/axondeepseg/model_postprocess_touching_myelin/releases/download/r20240313/model_myelin_cutter.zip', + 'description': 'Postprocessing myelin cutter model to delineate touching myelin. To use directly on an axonmyelin segmentation mask.', + 'contrasts': ['any'], + } } diff --git a/nn_axondeepseg.py b/nn_axondeepseg.py index 3e9c6b9..8542b8a 100644 --- a/nn_axondeepseg.py +++ b/nn_axondeepseg.py @@ -37,6 +37,8 @@ def get_parser(): parser.add_argument('--path-out', help='Path to output directory.', required=True) parser.add_argument('--path-model', default=None, help='Path to the model directory. This folder should contain individual folders like fold_0, fold_1, etc.',) + parser.add_argument('--use-best', action='store_true', default=False, + help='Use the best checkpoints instead of the final ones. Default: False') parser.add_argument('--use-gpu', action='store_true', default=False, help='Use GPU for inference. Default: False') return parser @@ -90,7 +92,12 @@ def main(): ) logger.info('Running inference on device: {}'.format(predictor.device)) # initialize network architecture, load checkpoint - predictor.initialize_from_trained_model_folder(path_model, use_folds=None) + checkpoint_name = 'checkpoint_final.pth' if not args.use_best else 'checkpoint_best.pth' + predictor.initialize_from_trained_model_folder( + path_model, + use_folds=None, + checkpoint_name=checkpoint_name + ) logger.info('Model loaded successfully.')