diff --git a/nbdt/bin/nbdt b/nbdt/bin/nbdt index d559cad..8027993 100644 --- a/nbdt/bin/nbdt +++ b/nbdt/bin/nbdt @@ -3,16 +3,12 @@ from nbdt.model import SoftNBDT from pytorchcv.models.wrn_cifar import wrn28_10_cifar10 -from PIL import Image -from urllib.request import urlopen, Request from torchvision import transforms -import io +from nbdt.utils import DATASET_TO_CLASSES import sys assert len(sys.argv) > 1, "Need to pass image URL or image path as argument" -classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] - # load pretrained NBDT model = wrn28_10_cifar10() model = SoftNBDT( @@ -22,27 +18,17 @@ model = SoftNBDT( pretrained=True, arch='wrn28_10_cifar10') -# load image -path = sys.argv[1] -headers = { - 'User-Agent': 'Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.3' -} -if 'http' in path: - request = Request(path, headers=headers) - file = io.BytesIO(urlopen(request).read()) -else: - file = path -im = Image.open(file) - -# transform image +# load + transform image +im = load_image_from_path(sys.argv[1]) transforms = transforms.Compose([ transforms.Resize(32), transforms.CenterCrop(32), - transforms.ToTensor() + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) x = transforms(im)[None] # run inference outputs = model(x) -cls = classes[outputs[0]] +cls = DATASET_TO_CLASSES['CIFAR10'][outputs[0]] print(cls) diff --git a/nbdt/utils.py b/nbdt/utils.py index 3b4ad34..0e4e699 100644 --- a/nbdt/utils.py +++ b/nbdt/utils.py @@ -9,9 +9,12 @@ import math import numpy as np +from urllib.request import urlopen, Request +from PIL import Image import torch.nn as nn import torch.nn.init as init from pathlib import Path +import io # tree-generation consntants METHODS = ('wordnet', 'random', 'induced') @@ -22,6 +25,12 @@ 'TinyImagenet200': 200, 'Imagenet1000': 1000 } +DATASET_TO_CLASSES = { + 'CIFAR10': [ + 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', + 'horse', 'ship', 'truck' + ] +} def fwd(): @@ -61,6 +70,19 @@ def populate_kwargs(args, kwargs, object, name='Dataset', keys=(), globals={}): f'{key}: {value}') +def load_image_from_path(path): + """Path can be local or a URL""" + headers = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.3' + } + if 'http' in path: + request = Request(path, headers=headers) + file = io.BytesIO(urlopen(request).read()) + else: + file = path + return Image.open(file) + + class Colors: RED = '\x1b[31m' GREEN = '\x1b[32m'