From 5a8984ceb3edff04c50200e82499f9b3da4da69b Mon Sep 17 00:00:00 2001 From: Joseph Paul Cohen Date: Wed, 11 Sep 2024 00:07:55 -0700 Subject: [PATCH] add cache_dir arg to allow files to be stored anywhere (#159) * add cache_dir to a few models --- .github/workflows/tests.yml | 4 +- torchxrayvision/autoencoders.py | 35 ++++++----------- .../baseline_models/chestx_det/__init__.py | 12 +++++- torchxrayvision/models.py | 39 +++++++------------ torchxrayvision/utils.py | 4 ++ 5 files changed, 41 insertions(+), 53 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f9b300c..cb16ab1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -31,8 +31,8 @@ jobs: strategy: max-parallel: 2 matrix: - python-version: ['3.9'] - torch-version: [2.1.1] + python-version: ['3.11'] + torch-version: [2.4.1] os: [ubuntu-latest, macos-latest, windows-latest] # only run ubuntu for now because the other ones fail for no reason, macos-latest, windows-latest] # Steps represent a sequence of tasks that will be executed as part of the job diff --git a/torchxrayvision/autoencoders.py b/torchxrayvision/autoencoders.py index a03e1ee..efc01ce 100644 --- a/torchxrayvision/autoencoders.py +++ b/torchxrayvision/autoencoders.py @@ -4,6 +4,7 @@ import os import sys import requests +from . import utils model_urls = {} @@ -218,7 +219,7 @@ def ResNetAE101(**kwargs): return _ResNetAE(Bottleneck, DeconvBottleneck, [3, 4, 23, 2], 1, **kwargs) -def ResNetAE(weights=None): +def ResNetAE(weights=None, cache_dir=None): """A ResNet based autoencoder. Possible weights for this class include: @@ -231,6 +232,11 @@ def ResNetAE(weights=None): z = ae.encode(image) image2 = ae.decode(z) + + params: + weights (str): Weights to use. See above for options. + cache_dir (str): Override directory used to store cached weights (default: ~/.torchxrayvision/) + """ if weights == None: @@ -245,14 +251,17 @@ def ResNetAE(weights=None): # load pretrained models url = model_urls[weights]["weights_url"] weights_filename = os.path.basename(url) - weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data")) + if cache_dir is None: + weights_storage_folder = utils.get_cache_dir() + else: + weights_storage_folder = cache_dir weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename)) if not os.path.isfile(weights_filename_local): print("Downloading weights...") print("If this fails you can run `wget {} -O {}`".format(url, weights_filename_local)) pathlib.Path(weights_storage_folder).mkdir(parents=True, exist_ok=True) - download(url, weights_filename_local) + utils.download(url, weights_filename_local) try: state_dict = torch.load(weights_filename_local, map_location='cpu') @@ -268,23 +277,3 @@ def ResNetAE(weights=None): ae.description = model_urls[weights]["description"] return ae - - -# from here https://sumit-ghosh.com/articles/python-download-progress-bar/ -def download(url, filename): - with open(filename, 'wb') as f: - response = requests.get(url, stream=True) - total = response.headers.get('content-length') - - if total is None: - f.write(response.content) - else: - downloaded = 0 - total = int(total) - for data in response.iter_content(chunk_size=max(int(total / 1000), 1024 * 1024)): - downloaded += len(data) - f.write(data) - done = int(50 * downloaded / total) - sys.stdout.write('\r[{}{}]'.format('█' * done, '.' * (50 - done))) - sys.stdout.flush() - sys.stdout.write('\n') diff --git a/torchxrayvision/baseline_models/chestx_det/__init__.py b/torchxrayvision/baseline_models/chestx_det/__init__.py index ccfd74f..cee1f21 100644 --- a/torchxrayvision/baseline_models/chestx_det/__init__.py +++ b/torchxrayvision/baseline_models/chestx_det/__init__.py @@ -9,6 +9,7 @@ import torchvision from .ptsemseg.pspnet import pspnet +from ... import utils def _convert_state_dict(state_dict): @@ -51,6 +52,10 @@ class PSPNet(nn.Module): url = {https://arxiv.org/abs/2104.10326}, year = {2021} } + + params: + cache_dir (str): Override directory used to store cached weights (default: ~/.torchxrayvision/) + """ targets: List[str] = [ @@ -62,7 +67,7 @@ class PSPNet(nn.Module): ] """""" - def __init__(self): + def __init__(self, cache_dir:str = None): super(PSPNet, self).__init__() @@ -78,7 +83,10 @@ def __init__(self): url = "https://github.com/mlmed/torchxrayvision/releases/download/v1/pspnet_chestxray_best_model_4.pth" weights_filename = os.path.basename(url) - weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data")) + if cache_dir is None: + weights_storage_folder = utils.get_cache_dir() + else: + weights_storage_folder = cache_dir self.weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename)) if not os.path.isfile(self.weights_filename_local): diff --git a/torchxrayvision/models.py b/torchxrayvision/models.py index dc72bbe..6f26750 100644 --- a/torchxrayvision/models.py +++ b/torchxrayvision/models.py @@ -11,6 +11,7 @@ import numpy as np from collections import OrderedDict from . import datasets +from . import utils import warnings warnings.filterwarnings("ignore") @@ -191,6 +192,7 @@ class DenseNet(nn.Module): model = xrv.models.DenseNet(weights="densenet121-res224-mimic_ch") # MIMIC-CXR (MIT) :param weights: Specify a weight name to load pre-trained weights + :param cache_dir: Override where the weights will be stored (default is ~/.torchxrayvision/) :param op_threshs: Specify a weight name to load pre-trained weights :param apply_sigmoid: Apply a sigmoid @@ -227,6 +229,7 @@ def __init__(self, num_classes=len(datasets.default_pathologies), in_channels=1, weights=None, + cache_dir=None, op_threshs=None, apply_sigmoid=False ): @@ -291,7 +294,7 @@ def __init__(self, self.register_buffer('op_threshs', op_threshs) if self.weights != None: - self.weights_filename_local = get_weights(weights) + self.weights_filename_local = get_weights(weights, cache_dir) try: savedmodel = torch.load(self.weights_filename_local, map_location='cpu') @@ -355,6 +358,7 @@ class ResNet(nn.Module): model = xrv.models.ResNet(weights="resnet50-res512-all") :param weights: Specify a weight name to load pre-trained weights + :param cache_dir: Override where the weights will be stored (default is ~/.torchxrayvision/) :param op_threshs: Specify a weight name to load pre-trained weights :param apply_sigmoid: Apply a sigmoid @@ -382,7 +386,7 @@ class ResNet(nn.Module): ] """""" - def __init__(self, weights: str = None, apply_sigmoid: bool = False): + def __init__(self, weights: str = None, apply_sigmoid: bool = False, cache_dir: str = None): super(ResNet, self).__init__() self.weights = weights @@ -392,7 +396,7 @@ def __init__(self, weights: str = None, apply_sigmoid: bool = False): possible_weights = [k for k in model_urls.keys() if k.startswith("resnet")] raise Exception("Weights value must be in {}".format(possible_weights)) - self.weights_filename_local = get_weights(weights) + self.weights_filename_local = get_weights(weights, cache_dir=cache_dir) self.weights_dict = model_urls[weights] self.targets = model_urls[weights]["labels"] self.pathologies = self.targets # keep to be backward compatible @@ -546,39 +550,22 @@ def get_model(weights: str, **kwargs): raise Exception("Unknown model") -def get_weights(weights: str): +def get_weights(weights: str, cache_dir:str = None): if not weights in model_urls: raise Exception("Weights not found. Valid options: {}".format(list(model_urls.keys()))) url = model_urls[weights]["weights_url"] weights_filename = os.path.basename(url) - weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data")) + if cache_dir is None: + weights_storage_folder = utils.get_cache_dir() + else: + weights_storage_folder = cache_dir weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename)) if not os.path.isfile(weights_filename_local): print("Downloading weights...") print("If this fails you can run `wget {} -O {}`".format(url, weights_filename_local)) pathlib.Path(weights_storage_folder).mkdir(parents=True, exist_ok=True) - download(url, weights_filename_local) + utils.download(url, weights_filename_local) return weights_filename_local - - -# from here https://sumit-ghosh.com/articles/python-download-progress-bar/ -def download(url: str, filename: str): - with open(filename, 'wb') as f: - response = requests.get(url, stream=True) - total = response.headers.get('content-length') - - if total is None: - f.write(response.content) - else: - downloaded = 0 - total = int(total) - for data in response.iter_content(chunk_size=max(int(total / 1000), 1024 * 1024)): - downloaded += len(data) - f.write(data) - done = int(50 * downloaded / total) - sys.stdout.write('\r[{}{}]'.format('█' * done, '.' * (50 - done))) - sys.stdout.flush() - sys.stdout.write('\n') diff --git a/torchxrayvision/utils.py b/torchxrayvision/utils.py index e829d22..8e6fa0f 100644 --- a/torchxrayvision/utils.py +++ b/torchxrayvision/utils.py @@ -3,6 +3,7 @@ import numpy as np import skimage import torch +import os from os import PathLike from numpy import ndarray @@ -10,6 +11,9 @@ from tqdm.autonotebook import tqdm +def get_cache_dir(): + return os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data/")) + def in_notebook(): try: from IPython import get_ipython