diff --git a/README.md b/README.md index 9c410a9..1aa7278 100644 --- a/README.md +++ b/README.md @@ -85,11 +85,11 @@ For RGB-Depth semantic segmentation, the generation of HHA maps from Depth maps 2. Config - Edit config file in `configs.py`, including dataset and network settings. + Edit config file in configs directory, including dataset and network settings. 3. Run multi GPU distributed training: ```shell - $ CUDA_VISIBLE_DEVICES="GPU IDs" python -m torch.distributed.launch --nproc_per_node="GPU numbers you want to use" train.py + $ CUDA_VISIBLE_DEVICES="GPU IDs" python -m torch.distributed.launch --nproc_per_node="GPU numbers you want to use" train.py -f "config_file" ``` - The tensorboard file is saved in `log__/tb/` directory. @@ -98,7 +98,7 @@ For RGB-Depth semantic segmentation, the generation of HHA maps from Depth maps ### Evaluation Run the evaluation by: ```shell -CUDA_VISIBLE_DEVICES="GPU IDs" python eval.py -d="Device ID" -e="epoch number or range" +CUDA_VISIBLE_DEVICES="GPU IDs" python eval.py -f "config_file" -d="Device ID" -e="epoch number or range" ``` If you want to use multi GPUs please specify multiple Device IDs (0,1,2...). diff --git a/config.py b/configs/NYUDepthv2/NYUDepthv2_mit_b2_e500_v1.py similarity index 82% rename from config.py rename to configs/NYUDepthv2/NYUDepthv2_mit_b2_e500_v1.py index 9295ecc..f6f95fd 100644 --- a/config.py +++ b/configs/NYUDepthv2/NYUDepthv2_mit_b2_e500_v1.py @@ -17,7 +17,6 @@ C.abs_dir = osp.realpath(".") # Dataset config -"""Dataset Path""" C.dataset_name = 'NYUDepthv2' C.dataset_path = osp.join(C.root_dir, 'datasets', 'NYUDepthv2') C.rgb_root_folder = osp.join(C.dataset_path, 'RGB') @@ -29,14 +28,14 @@ # True for most dataset valid, Faslse for MFNet(?) C.x_root_folder = osp.join(C.dataset_path, 'HHA') C.x_format = '.jpg' -C.x_is_single_channel = False # True for raw depth, thermal and aolp/dolp(not aolp/dolp tri) input +C.x_is_single_channel = False # True for raw depth, thermal and aolp/dolp(not aolp/dolp tri) input C.train_source = osp.join(C.dataset_path, "train.txt") C.eval_source = osp.join(C.dataset_path, "test.txt") C.is_test = False C.num_train_imgs = 795 C.num_eval_imgs = 654 C.num_classes = 40 -C.class_names = ['wall','floor','cabinet','bed','chair','sofa','table','door','window','bookshelf','picture','counter','blinds', +C.class_names = ['wall','floor','cabinet','bed','chair','sofa','table','door','window','bookshelf','picture','counter','blinds', 'desk','shelves','curtain','dresser','pillow','mirror','floor mat','clothes','ceiling','books','refridgerator', 'television','paper','towel','shower curtain','box','whiteboard','person','night stand','toilet', 'sink','lamp','bathtub','bag','otherstructure','otherfurniture','otherprop'] @@ -49,7 +48,7 @@ C.norm_std = np.array([0.229, 0.224, 0.225]) """ Settings for network, this would be different for each kind of model""" -C.backbone = 'mit_b2' # Remember change the path below. +C.backbone = 'mit_b2' # Remember change the path below. C.pretrained_model = C.root_dir + '/pretrained/segformer/mit_b2.pth' C.decoder = 'MLPDecoder' C.decoder_embed_dim = 512 @@ -62,7 +61,7 @@ C.weight_decay = 0.01 C.batch_size = 8 C.nepochs = 500 -C.niters_per_epoch = C.num_train_imgs // C.batch_size + 1 +C.niters_per_epoch = C.num_train_imgs // C.batch_size + 1 C.num_workers = 16 C.train_scale_array = [0.5, 0.75, 1, 1.25, 1.5, 1.75] C.warm_up_epoch = 10 @@ -74,9 +73,9 @@ """Eval Config""" C.eval_iter = 25 C.eval_stride_rate = 2 / 3 -C.eval_scale_array = [1] # [0.75, 1, 1.25] # -C.eval_flip = False # True # -C.eval_crop_size = [480, 640] # [height weight] +C.eval_scale_array = [1] # [0.75, 1, 1.25] # +C.eval_flip = False # True # +C.eval_crop_size = [480, 640] # [height weight] """Store Config""" C.checkpoint_start_epoch = 250 @@ -88,7 +87,8 @@ def add_path(path): sys.path.insert(0, path) add_path(osp.join(C.root_dir)) -C.log_dir = osp.abspath('log_' + C.dataset_name + '_' + C.backbone) +config_name=os.path.basename(__file__).split(".")[0] +C.log_dir = osp.abspath('logs/log_' + config_name) C.tb_dir = osp.abspath(osp.join(C.log_dir, "tb")) C.log_dir_link = C.log_dir C.checkpoint_dir = osp.abspath(osp.join(C.log_dir, "checkpoint")) @@ -106,5 +106,5 @@ def add_path(path): '-tb', '--tensorboard', default=False, action='store_true') args = parser.parse_args() - if args.tensorboard: - open_tensorboard() \ No newline at end of file + # if args.tensorboard: + # open_tensorboard() \ No newline at end of file diff --git a/dataloader/RGBXDataset.py b/dataloader/RGBXDataset.py index ee16656..671eb4c 100644 --- a/dataloader/RGBXDataset.py +++ b/dataloader/RGBXDataset.py @@ -51,7 +51,7 @@ def __getitem__(self, index): x = self._open_image(x_path, cv2.IMREAD_GRAYSCALE) x = cv2.merge([x, x, x]) else: - x = self._open_image(x_path, cv2.COLOR_BGR2RGB) + x = self._open_image(x_path, cv2.COLOR_BGR2RGB) if self.preprocess is not None: rgb, gt, x = self.preprocess(rgb, gt, x) diff --git a/dataloader/dataloader.py b/dataloader/dataloader.py index e264080..baad4bf 100644 --- a/dataloader/dataloader.py +++ b/dataloader/dataloader.py @@ -3,9 +3,10 @@ import numpy as np from torch.utils import data import random -from config import config +# from config import config from utils.transforms import generate_random_crop_pos, random_crop_pad_to_shape, normalize + def random_mirror(rgb, gt, modal_x): if random.random() >= 0.5: rgb = cv2.flip(rgb, 1) @@ -14,6 +15,7 @@ def random_mirror(rgb, gt, modal_x): return rgb, gt, modal_x + def random_scale(rgb, gt, modal_x, scales): scale = random.choice(scales) sh = int(rgb.shape[0] * scale) @@ -24,20 +26,22 @@ def random_scale(rgb, gt, modal_x, scales): return rgb, gt, modal_x, scale + class TrainPre(object): - def __init__(self, norm_mean, norm_std): - self.norm_mean = norm_mean - self.norm_std = norm_std + def __init__(self, config): + self.config = config + self.norm_mean = config.norm_mean + self.norm_std = config.norm_std def __call__(self, rgb, gt, modal_x): rgb, gt, modal_x = random_mirror(rgb, gt, modal_x) - if config.train_scale_array is not None: - rgb, gt, modal_x, scale = random_scale(rgb, gt, modal_x, config.train_scale_array) + if self.config.train_scale_array is not None: + rgb, gt, modal_x, scale = random_scale(rgb, gt, modal_x, self.config.train_scale_array) rgb = normalize(rgb, self.norm_mean, self.norm_std) modal_x = normalize(modal_x, self.norm_mean, self.norm_std) - crop_size = (config.image_height, config.image_width) + crop_size = (self.config.image_height, self.config.image_width) crop_pos = generate_random_crop_pos(rgb.shape[:2], crop_size) p_rgb, _ = random_crop_pad_to_shape(rgb, crop_pos, crop_size, 0) @@ -49,11 +53,13 @@ def __call__(self, rgb, gt, modal_x): return p_rgb, p_gt, p_modal_x + class ValPre(object): def __call__(self, rgb, gt, modal_x): return rgb, gt, modal_x -def get_train_loader(engine, dataset): + +def get_train_loader(engine, dataset, config=None): data_setting = {'rgb_root': config.rgb_root_folder, 'rgb_format': config.rgb_format, 'gt_root': config.gt_root_folder, @@ -66,7 +72,7 @@ def get_train_loader(engine, dataset): 'train_source': config.train_source, 'eval_source': config.eval_source, 'class_names': config.class_names} - train_preprocess = TrainPre(config.norm_mean, config.norm_std) + train_preprocess = TrainPre(config) train_dataset = dataset(data_setting, "train", train_preprocess, config.batch_size * config.niters_per_epoch) diff --git a/engine/engine.py b/engine/engine.py index 8fd3b0b..fe62096 100644 --- a/engine/engine.py +++ b/engine/engine.py @@ -61,6 +61,8 @@ def __init__(self, custom_parser=None): def inject_default_parser(self): p = self.parser + p.add_argument("-f", "--config_file", default=None, type=str, + help="plz input your experiment description file", ) p.add_argument('-d', '--devices', default='', help='set data parallel training') p.add_argument('-c', '--continue', type=extant_file, diff --git a/engine/evaluator.py b/engine/evaluator.py index 60ba1e6..7f97af0 100644 --- a/engine/evaluator.py +++ b/engine/evaluator.py @@ -50,29 +50,34 @@ def run(self, model_path, model_indice, log_file, log_file_link): models = [model_indice, ] elif "-" in model_indice: start_epoch = int(model_indice.split("-")[0]) - end_epoch = model_indice.split("-")[1] - - models = os.listdir(model_path) - models.remove("epoch-last.pth") - sorted_models = [None] * len(models) - model_idx = [0] * len(models) - - for idx, m in enumerate(models): - num = m.split(".")[0].split("-")[1] - model_idx[idx] = num - sorted_models[idx] = m - model_idx = np.array([int(i) for i in model_idx]) - - down_bound = model_idx >= start_epoch - up_bound = [True] * len(sorted_models) - if end_epoch: - end_epoch = int(end_epoch) - assert start_epoch < end_epoch - up_bound = model_idx <= end_epoch - bound = up_bound * down_bound - model_slice = np.array(sorted_models)[bound] - models = [os.path.join(model_path, model) for model in - model_slice] + end_epoch = int(model_indice.split("-")[1]) + + models = [] + for i in range(start_epoch, end_epoch + 1): + models.append(os.path.join(model_path, 'epoch-%s.pth' % i)) + + # models = os.listdir(model_path) + # models.remove("epoch-last.pth") + # sorted_models = [None] * len(models) + # model_idx = [0] * len(models) + # + # for idx, m in enumerate(models): + # num = m.split(".")[0].split("-")[1] + # model_idx[idx] = num + # sorted_models[idx] = m + # model_idx = np.array([int(i) for i in model_idx]) + # model_idx.sort() + # + # down_bound = model_idx >= start_epoch + # up_bound = [True] * len(sorted_models) + # if end_epoch: + # end_epoch = int(end_epoch) + # assert start_epoch < end_epoch + # up_bound = model_idx <= end_epoch + # bound = up_bound * down_bound + # model_slice = np.array(sorted_models)[bound] + # models = [os.path.join(model_path, model) for model in + # model_slice] else: if os.path.exists(model_path): models = [os.path.join(model_path, 'epoch-%s.pth' % model_indice), ] @@ -306,7 +311,10 @@ def process_image(self, img, crop_size=None): def sliding_eval_rgbX(self, img, modal_x, crop_size, stride_rate, device=None): crop_size = to_2tuple(crop_size) ori_rows, ori_cols, _ = img.shape - processed_pred = np.zeros((ori_rows, ori_cols, self.class_num)) + if self.class_num < 2: + processed_pred = np.zeros((ori_rows, ori_cols)) + else: + processed_pred = np.zeros((ori_rows, ori_cols, self.class_num)) for s in self.multi_scales: img_scale = cv2.resize(img, None, fx=s, fy=s, interpolation=cv2.INTER_LINEAR) @@ -319,7 +327,8 @@ def sliding_eval_rgbX(self, img, modal_x, crop_size, stride_rate, device=None): processed_pred += self.scale_process_rgbX(img_scale, modal_x_scale, (ori_rows, ori_cols), crop_size, stride_rate, device) - pred = processed_pred.argmax(2) + if self.class_num > 1: + pred = processed_pred.argmax(2) return pred diff --git a/eval.py b/eval.py index b00b131..26c6b62 100644 --- a/eval.py +++ b/eval.py @@ -6,18 +6,21 @@ import torch import torch.nn as nn -from config import config +# from config import config +from utils.config_utils import get_config_by_file from utils.pyt_utils import ensure_dir, link_file, load_model, parse_devices -from utils.visualize import print_iou, show_img +from utils.visualize import print_iou, show_img, get_class_colors from engine.evaluator import Evaluator from engine.logger import get_logger from utils.metric import hist_info, compute_score from dataloader.RGBXDataset import RGBXDataset from models.builder import EncoderDecoder as segmodel from dataloader.dataloader import ValPre +from PIL import Image logger = get_logger() + class SegEvaluator(Evaluator): def func_per_iteration(self, data, device): img = data['data'] @@ -36,7 +39,7 @@ def func_per_iteration(self, data, device): # save colored result result_img = Image.fromarray(pred.astype(np.uint8), mode='P') - class_colors = get_class_colors() + class_colors = get_class_colors(config.num_classes) palette_list = list(np.array(class_colors).flat) if len(palette_list) < 768: palette_list += [0] * (768 - len(palette_list)) @@ -75,8 +78,11 @@ def compute_metric(self, results): dataset.class_names, show_no_back=False) return result_line + if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument("-f", "--config_file", default=None, type=str, + help="plz input your experiment description file",) parser.add_argument('-e', '--epochs', default='last', type=str) parser.add_argument('-d', '--devices', default='0', type=str) parser.add_argument('-v', '--verbose', default=False, action='store_true') @@ -86,6 +92,7 @@ def compute_metric(self, results): args = parser.parse_args() all_dev = parse_devices(args.devices) + config = get_config_by_file(args.config_file) network = segmodel(cfg=config, criterion=None, norm_layer=nn.BatchNorm2d) data_setting = {'rgb_root': config.rgb_root_folder, diff --git a/models/builder.py b/models/builder.py index 40a8dc9..962b35b 100644 --- a/models/builder.py +++ b/models/builder.py @@ -47,6 +47,10 @@ def __init__(self, cfg=None, criterion=nn.CrossEntropyLoss(reduction='mean', ign self.channels = [32, 64, 160, 256] from .encoders.dual_segformer import mit_b0 as backbone self.backbone = backbone(norm_fuse=norm_layer) + elif cfg.backbone == 'mit_b2_s': + logger.info('Using backbone: Segformer-B2') + from .encoders.segformer import mit_b2 as backbone + self.backbone = backbone(norm_fuse=norm_layer) else: logger.info('Using backbone: Segformer-B2') from .encoders.dual_segformer import mit_b2 as backbone diff --git a/train.py b/train.py index 60fb241..582687c 100644 --- a/train.py +++ b/train.py @@ -1,5 +1,7 @@ import os.path as osp import os +# os.environ['CUDA_VISIBLE_DEVICES'] = '3' +# os.environ['CUDA_LAUNCH_BLOCKING'] = '1' import sys import time import argparse @@ -11,7 +13,8 @@ import torch.backends.cudnn as cudnn from torch.nn.parallel import DistributedDataParallel -from config import config +# from config import config +from utils.config_utils import get_config_by_file from dataloader.dataloader import get_train_loader from models.builder import EncoderDecoder as segmodel from dataloader.RGBXDataset import RGBXDataset @@ -20,6 +23,7 @@ from engine.engine import Engine from engine.logger import get_logger from utils.pyt_utils import all_reduce_tensor +from utils.losses import BCEDiceLoss from tensorboardX import SummaryWriter @@ -31,6 +35,7 @@ with Engine(custom_parser=parser) as engine: args = parser.parse_args() + config = get_config_by_file(args.config_file) cudnn.benchmark = True seed = config.seed if engine.distributed: @@ -40,7 +45,7 @@ torch.cuda.manual_seed(seed) # data loader - train_loader, train_sampler = get_train_loader(engine, RGBXDataset) + train_loader, train_sampler = get_train_loader(engine, RGBXDataset, config) if (engine.distributed and (engine.local_rank == 0)) or (not engine.distributed): tb_dir = config.tb_dir + '/{}'.format(time.strftime("%b%d_%d-%H-%M", time.localtime())) @@ -49,14 +54,17 @@ engine.link_tb(tb_dir, generate_tb_dir) # config network and criterion - criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=config.background) + if config.num_classes > 2: + criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=config.background) + else: + criterion = BCEDiceLoss() if engine.distributed: BatchNorm2d = nn.SyncBatchNorm else: BatchNorm2d = nn.BatchNorm2d - model=segmodel(cfg=config, criterion=criterion, norm_layer=BatchNorm2d) + model = segmodel(cfg=config, criterion=criterion, norm_layer=BatchNorm2d) # group weight and config optimizer base_lr = config.lr @@ -114,6 +122,11 @@ gts = minibatch['label'] modal_xs = minibatch['modal_x'] + # gts = torch.unsqueeze(gts, axis=1) + # print(gts.dtype) + # gts = gts.to(torch.float) + # print(gts.dtype) + imgs = imgs.cuda(non_blocking=True) gts = gts.cuda(non_blocking=True) modal_xs = modal_xs.cuda(non_blocking=True) diff --git a/utils/__init__.py b/utils/__init__.py index e69de29..0eb1860 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -0,0 +1 @@ +# from . import (losses, functional, base) \ No newline at end of file diff --git a/utils/_modules.py b/utils/_modules.py new file mode 100644 index 0000000..b274045 --- /dev/null +++ b/utils/_modules.py @@ -0,0 +1,23 @@ +import torch.nn as nn + + +class Activation(nn.Module): + + def __init__(self, name, **params): + super().__init__() + + if name is None or name == 'identity': + self.activation = nn.Identity(**params) + elif name == 'sigmoid': + self.activation = nn.Sigmoid() + elif name == 'softmax': + self.activation = nn.Softmax(**params) + elif name == 'logsoftmax': + self.activation = nn.LogSoftmax(**params) + elif callable(name): + self.activation = name(**params) + else: + raise ValueError('Activation should be callable/sigmoid/softamx/logsoftmax/None; got {}'.format(name)) + + def forward(self, x): + return self.activation(x) diff --git a/utils/base.py b/utils/base.py new file mode 100644 index 0000000..2d35ac0 --- /dev/null +++ b/utils/base.py @@ -0,0 +1,72 @@ +import re +import torch.nn as nn + + +class BaseObject(nn.Module): + + def __init__(self, name=None): + super().__init__() + self._name = name + + @property + def __name__(self): + if self._name is None: + name = self.__class__.__name__ + s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() + else: + return self._name + + +class Metric(BaseObject): + pass + + +class Loss(BaseObject): + + def __add__(self, other): + if isinstance(other, Loss): + return SumOfLosses(self, other) + else: + raise ValueError('Loss should be inherited from `Loss` class') + + def __radd__(self, other): + return self.__add__(other) + + def __mul__(self, value): + if isinstance(value, (int, float)): + return MultipliedLoss(self, value) + else: + raise ValueError('Loss should be inherited from `BaseLoss` class') + + def __rmul__(self, other): + return self.__mul__(other) + + +class SumOfLosses(Loss): + + def __init__(self, l1, l2): + name = '{} + {}'.format(l1.__name__, l2.__name__) + super().__init__(name=name) + self.l1 = l1 + self.l2 = l2 + + def forward(self, *inputs): + return self.l1(*inputs) + self.l2(*inputs) + + +class MultipliedLoss(Loss): + + def __init__(self, loss, multiplier): + + # resolve name + if len(loss.__name__.split('+')) > 1: + name = '{} * ({})'.format(multiplier, loss.__name__) + else: + name = '{} * {}'.format(multiplier, loss.__name__) + super().__init__(name=name) + self.loss = loss + self.multiplier = multiplier + + def forward(self, *inputs): + return self.multiplier * self.loss(*inputs) diff --git a/utils/config_utils.py b/utils/config_utils.py new file mode 100644 index 0000000..c919e4f --- /dev/null +++ b/utils/config_utils.py @@ -0,0 +1,13 @@ +import importlib +import os +import sys + + +def get_config_by_file(config_file): + try: + sys.path.append(os.path.dirname(config_file)) + current_config = importlib.import_module(os.path.basename(config_file).split(".")[0]) + config = current_config.config + except Exception: + raise ImportError("{} doesn't contains class named 'Exp'".format(config_file)) + return config \ No newline at end of file diff --git a/utils/functional.py b/utils/functional.py new file mode 100644 index 0000000..9ce1805 --- /dev/null +++ b/utils/functional.py @@ -0,0 +1,201 @@ +import torch + + +def _ignore_channels(*xs, ignore_channels=None): + if ignore_channels is None: + return xs + else: + channels = [channel for channel in range(xs[0].shape[1]) if channel not in ignore_channels] + xs = [torch.index_select(x, dim=1, index=torch.tensor(channels).to(x.device)) for x in xs] + return xs + + +def _take_channels(*xs, take_channels=None): + if take_channels is None: + return xs + else: + xs = [torch.index_select(x, dim=1, index=torch.tensor(take_channels).to(x.device)) for x in xs] + return xs + + +def _take_non_empty(pr, gt, drop_empty=True): + if drop_empty: + mask = gt[gt.sum(dim=(1, 2, 3)) > 0] + gt = gt[mask] + pr = pr[mask] + return pr, gt + + +def _threshold(x, threshold=None): + if threshold is not None: + return (x > threshold).type(x.dtype) + else: + return x + + +def _sum(x, per_image=False): + if per_image: + return torch.sum(x, dim=(2, 3)) + else: + return torch.sum(x, dim=(0, 2, 3)) + + +def _average(x, weights=None): + """""" + if weights is not None: + x = x * torch.tensor(weights, dtype=x.dtype, requires_grad=False).to(x.device) + + if x.dim() == 2: + x = x.mean(dim=0) + + return x.mean() + + +def iou(pr, gt, eps=1e-7, threshold=None, ignore_channels=None, + class_weights=None, per_image=False, drop_empty=False, take_channels=None): + """Calculate Intersection over Union between ground truth and prediction + Args: + pr (torch.Tensor): predicted tensor + gt (torch.Tensor): ground truth tensor + eps (float): epsilon to avoid zero division + threshold: threshold for outputs binarization + Returns: + float: IoU (Jaccard) score + """ + + pr = _threshold(pr, threshold=threshold) + pr, gt = _ignore_channels(pr, gt, ignore_channels=ignore_channels) + pr, gt = _take_channels(pr, gt, take_channels=take_channels) + #pr, gt = _take_non_empty(pr, gt, drop_empty=drop_empty) + + # if gt.nelement() == 0: + # return 1. + + intersection = _sum(gt * pr, per_image) + union = _sum(gt, per_image) + _sum(pr, per_image) - intersection + score = (intersection + eps) / (union + eps) + + if drop_empty: + agg_mask = gt.sum(dim=(2, 3)) if per_image else gt.sum(dim=(0, 2, 3)) + empty_mask = 1. - (agg_mask > 1).float() + score = score * empty_mask + + return _average(score, class_weights) + + +jaccard = iou + + +def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, ignore_channels=None, + class_weights=None, per_image=False, drop_empty=False, take_channels=None): + """Calculate F-score between ground truth and prediction + Args: + pr (torch.Tensor): predicted tensor + gt (torch.Tensor): ground truth tensor + beta (float): positive constant + eps (float): epsilon to avoid zero division + threshold: threshold for outputs binarization + Returns: + float: F score + """ + + pr = _threshold(pr, threshold=threshold) + pr, gt = _ignore_channels(pr, gt, ignore_channels=ignore_channels) + pr, gt = _take_channels(pr, gt, take_channels=take_channels) + + if drop_empty: + pr = pr * (gt.sum(dim=(2, 3), keepdims=True) > 0).float() + + tp = _sum(gt * pr, per_image) + fp = _sum(pr, per_image) - tp + fn = _sum(gt, per_image) - tp + + score = ((1 + beta ** 2) * tp + eps) \ + / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + eps) + + # if drop_empty: + # agg_mask = gt.sum(dim=(2, 3)) if per_image else gt.sum(dim=(0, 2, 3)) + # non_empty_mask = (agg_mask > 1).float() + # score = score * non_empty_mask + + return _average(score, class_weights) + + +def accuracy(pr, gt, threshold=0.5, ignore_channels=None, take_channels=None): + """Calculate accuracy score between ground truth and prediction + Args: + pr (torch.Tensor): predicted tensor + gt (torch.Tensor): ground truth tensor + eps (float): epsilon to avoid zero division + threshold: threshold for outputs binarization + Returns: + float: precision score + """ + pr = _threshold(pr, threshold=threshold) + pr, gt = _take_channels(pr, gt, take_channels=take_channels) + pr, gt = _ignore_channels(pr, gt, ignore_channels=ignore_channels) + + tp = torch.sum(gt == pr).float() + score = tp / gt.view(-1).shape[0] + return score + + +def precision(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): + """Calculate precision score between ground truth and prediction + Args: + pr (torch.Tensor): predicted tensor + gt (torch.Tensor): ground truth tensor + eps (float): epsilon to avoid zero division + threshold: threshold for outputs binarization + Returns: + float: precision score + """ + + pr = _threshold(pr, threshold=threshold) + pr, gt = _ignore_channels(pr, gt, ignore_channels=ignore_channels) + + tp = torch.sum(gt * pr) + fp = torch.sum(pr) - tp + + score = (tp + eps) / (tp + fp + eps) + + return score + + +def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): + """Calculate Recall between ground truth and prediction + Args: + pr (torch.Tensor): A list of predicted elements + gt (torch.Tensor): A list of elements that are to be predicted + eps (float): epsilon to avoid zero division + threshold: threshold for outputs binarization + Returns: + float: recall score + """ + + pr = _threshold(pr, threshold=threshold) + pr, gt = _ignore_channels(pr, gt, ignore_channels=ignore_channels) + + tp = torch.sum(gt * pr) + fn = torch.sum(gt) - tp + + score = (tp + eps) / (tp + fn + eps) + + return score + + +# def binary_crossentropy(pr, gt, eps=1e-7, pos_weight=1., neg_weight=1.): +# pr = torch.clamp(pr, eps, 1. - eps) +# gt = torch.clamp(gt, eps, 1. - eps) +# loss = - pos_weight * gt * pr.log() - neg_weight * (1. - gt) * (1. - pr).log() +# return loss + + +def binary_crossentropy(pr, gt, eps=1e-7, pos_weight=1., neg_weight=1., label_smoothing=None): + if label_smoothing is not None: + label_smoothing = torch.tensor(label_smoothing).to(gt.device) + gt = gt * (1. - label_smoothing) + (1. - gt) * label_smoothing + pr = torch.clamp(pr, eps, 1. - eps) + gt = torch.clamp(gt, eps, 1. - eps) + loss = - pos_weight * gt * torch.log(pr/gt) - neg_weight * (1. - gt) * torch.log((1. - pr) / (1. - gt)) + return loss diff --git a/utils/losses.py b/utils/losses.py new file mode 100644 index 0000000..08a3d1b --- /dev/null +++ b/utils/losses.py @@ -0,0 +1,234 @@ +import torch +import torch.nn as nn + +from . import base +from . import functional as F +from . import _modules as modules + + +class JaccardLoss(base.Loss): + + def __init__(self, eps=1e-7, activation=None, ignore_channels=None, + per_image=False, class_weights=None, **kwargs): + super().__init__(**kwargs) + self.eps = eps + self.activation = modules.Activation(activation, dim=1) + self.per_image = per_image + self.ignore_channels = ignore_channels + self.class_weights = class_weights + + def forward(self, y_pr, y_gt): + y_pr = self.activation(y_pr) + return 1 - F.jaccard( + y_pr, y_gt, + eps=self.eps, + threshold=None, + ignore_channels=self.ignore_channels, + per_image=self.per_image, + class_weights=self.class_weights, + ) + + +class DiceLoss(base.Loss): + + def __init__(self, eps=1e-7, beta=1., activation=None, ignore_channels=None, + per_image=False, class_weights=None, drop_empty=False, + # smoothing=0., + aux_loss_weight=0, aux_loss_thres=50, **kwargs): # TODO add for more loss + super().__init__(**kwargs) + self.eps = eps + self.beta = beta + self.activation = modules.Activation(activation, dim=1) + self.ignore_channels = ignore_channels + self.per_image = per_image + self.class_weights = class_weights + self.drop_empty = drop_empty + # + # self.smoothing = smoothing + self.aux_loss_weight = aux_loss_weight + self.aux_loss_thres = aux_loss_thres + + def forward(self, y_pr, y_gt): + y_pr = self.activation(y_pr) + + # + if self.aux_loss_weight > 0: + gt = torch.sum(y_gt, axis=[2, 3], keepdim=True) + gt = (gt > self.aux_loss_thres).type(gt.dtype) + + if y_pr.shape[1] > 1: + pr = torch.argmax(y_pr, axis=1, keepdim=True) + else: + pr = (y_pr > 0.5) + pr = pr.type(y_pr.dtype) + pr = torch.sum(pr, axis=[2, 3], keepdim=True) + pr = (pr > self.aux_loss_thres).type(pr.dtype) + + class_loss = F.binary_crossentropy( + pr, gt, + # pos_weight=self.pos_weight, + # neg_weight=self.neg_weight, + # label_smoothing=self.label_smoothing, + ) + class_loss = class_loss.mean() + + # if self.smoothing > 0: # dice loss not smooth + # y_gt = y_gt * (1-self.smoothing) + # y_gt = y_gt + self.smoothing / y_gt.shape[1] + + dice_loss = 1 - F.f_score( + y_pr, y_gt, + beta=self.beta, + eps=self.eps, + threshold=None, + ignore_channels=self.ignore_channels, + per_image=self.per_image, + class_weights=self.class_weights, + drop_empty=self.drop_empty, + ) + + if self.aux_loss_weight > 0: + return dice_loss * (1 - self.aux_loss_weight) + class_loss * self.aux_loss_weight + + return dice_loss + + +class L1Loss(nn.L1Loss, base.Loss): + pass + + +class MSELoss(nn.MSELoss, base.Loss): + pass + + +class CrossEntropyLoss(nn.CrossEntropyLoss, base.Loss): + pass + + +class NLLLoss(nn.NLLLoss, base.Loss): + pass + + +class BCELoss(base.Loss): + + def __init__(self, pos_weight=1., neg_weight=1., reduction='mean', label_smoothing=None, scale=1): + super().__init__() + assert reduction in ['mean', None, False] + self.pos_weight = pos_weight + self.neg_weight = neg_weight + self.reduction = reduction + self.label_smoothing = label_smoothing + self.scale = scale + + def forward(self, pr, gt): + if len(gt.shape) < len(pr.shape): + gt = gt.unsqueeze(axis=-1) + loss = F.binary_crossentropy( + pr, gt, + pos_weight=self.pos_weight, + neg_weight=self.neg_weight, + label_smoothing=self.label_smoothing, + ) + + if self.reduction == 'mean': + loss = loss.mean() + + return loss * self.scale + + +class BinaryClassBCELoss(base.Loss): # + + def __init__(self, pos_weight=1., neg_weight=1., reduction='mean', label_smoothing=None): + super().__init__() + assert reduction in ['mean', None, False] + self.pos_weight = pos_weight + self.neg_weight = neg_weight + self.reduction = reduction + self.label_smoothing = label_smoothing + + def forward(self, pr, gt): + # TODO cal whole image class label + + loss = F.binary_crossentropy( + pr, gt, + pos_weight=self.pos_weight, + neg_weight=self.neg_weight, + label_smoothing=self.label_smoothing, + ) + + if self.reduction == 'mean': + loss = loss.mean() + + return loss + + +class BinaryFocalLoss(base.Loss): + def __init__(self, alpha=1, gamma=2, class_weights=None, logits=False, reduction='mean', label_smoothing=None): + super().__init__() + assert reduction in ['mean', None] + self.alpha = alpha + self.gamma = gamma + self.logits = logits + self.reduction = reduction + self.class_weights = class_weights if class_weights is not None else 1. + self.label_smoothing = label_smoothing + + def forward(self, pr, gt): + if self.logits: + bce_loss = nn.functional.binary_cross_entropy_with_logits(pr, gt, reduction='none') + else: + bce_loss = F.binary_crossentropy(pr, gt, label_smoothing=self.label_smoothing) + + pt = torch.exp(- bce_loss) + focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss + focal_loss = focal_loss * torch.tensor(self.class_weights).to(focal_loss.device) + + if self.reduction == 'mean': + focal_loss = focal_loss.mean() + + return focal_loss + + +class BCEWithLogitsLoss(nn.BCEWithLogitsLoss, base.Loss): + pass + + +class FocalDiceLoss(base.Loss): + + def __init__(self, lamdba=2): + super().__init__() + self.lamdba = lamdba + self.focal = BinaryFocalLoss() + self.dice = DiceLoss(eps=10.) + + def __call__(self, y_pred, y_true): + return self.lamdba * self.focal(y_pred, y_true) + self.dice(y_pred, y_true) + + +class BCEDiceLoss(base.Loss): + + def __init__(self, lamdba=2): + super().__init__() + self.lamdba = lamdba + self.bce = BCELoss() + self.dice = DiceLoss(eps=10.) + + def __call__(self, y_pred, y_true): + # print(y_pred[0,1,:,:]) + # print(torch.max(y_pred[0,1,:,:])) + # print(torch.min(y_pred[0, 1, :, :])) + # print(torch.mean(y_pred[0, 1, :, :])) + # print(y_true[3,:,:]) + # print(torch.max(y_true[3, :, :])) + # print(torch.min(y_true[3, :, :])) + if y_pred.shape[1] > 1: + y_pred = torch.sigmoid(y_pred) + # y_pred = torch.sigmoid(torch.softmax(y_pred, dim=1)) + y_pred = torch.unsqueeze(y_pred[:, 1, :, :], dim=1) + y_true = torch.unsqueeze(y_true, dim=1).float() # TODO + y_true[y_true == 255] = 0 + # print(y_pred.shape) + # print(y_pred.dtype) + # print(y_true.shape) + # print(y_true.dtype) + return self.lamdba * self.bce(y_pred, y_true) + self.dice(y_pred, y_true) diff --git a/utils/pyt_utils.py b/utils/pyt_utils.py index ff45832..d7331cb 100644 --- a/utils/pyt_utils.py +++ b/utils/pyt_utils.py @@ -191,6 +191,7 @@ def load_model(model, model_file, is_restore=False): return model + def parse_devices(input_devices): if input_devices.endswith('*'): devices = list(range(torch.cuda.device_count())) diff --git a/utils/visualize.py b/utils/visualize.py index 246199f..e3823c8 100644 --- a/utils/visualize.py +++ b/utils/visualize.py @@ -74,3 +74,24 @@ def print_iou(iou, freq_IoU, mean_pixel_acc, pixel_acc, class_names=None, show_n return line +def get_class_colors(num_classes): + def uint82bin(n, count=8): + """returns the binary of integer n, count refers to amount of bits""" + return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)]) + + N = num_classes + 1 + cmap = np.zeros((N, 3), dtype=np.uint8) + for i in range(N): + r, g, b = 0, 0, 0 + id = i + for j in range(7): + str_id = uint82bin(id) + r = r ^ (np.uint8(str_id[-1]) << (7 - j)) + g = g ^ (np.uint8(str_id[-2]) << (7 - j)) + b = b ^ (np.uint8(str_id[-3]) << (7 - j)) + id = id >> 3 + cmap[i, 0] = r + cmap[i, 1] = g + cmap[i, 2] = b + class_colors = cmap.tolist() + return class_colors \ No newline at end of file