UNet 3+ Unofficial PyTorch Implementation

This repository is an unofficial PyTorch implementation of UNet 3+.

I referred to the TensorFlow implementation of UNet 3+ on GitHub.

UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation

Give this repository a star ⭐ if you find my work useful.

Installation

Requirements

  • Python >= 3.10
  • PyTorch >= 2.2.0
  • CUDA 12.0

This code base has been tested with the Python and PyTorch versions listed above, and it is expected to work with newer versions as well.

  • Clone the code:
git clone https://github.com/russel0719/UNet-3-Plus-Pytorch.git UNet3P
cd UNet3P
  • Install the other requirements:
pip install -r requirements.txt

Code Structure

  • checkpoint: Model checkpoint and logs directory
  • configs: Configuration file
  • data: Dataset files (see Data Preparation for more details)
  • data_preparation: For LiTS data preparation and data verification
  • losses: Implementations of the UNet3+ hybrid loss function and the Dice coefficient
  • models: UNet3+ model files
  • utils: Generic utility functions
  • data_generator.py: Data generator for training, validation and testing
  • predict.py: Prediction script used to visualize model output
  • train.py: Training script

Config

Configurations are passed through a YAML file. For more details on the config file, read here.

Data Preparation

For details on dataset preparation, read here.

Models

This repo contains all three versions of UNet3+.

| # | Description | Model Name | Training Supported |
|---|-------------|------------|--------------------|
| 1 | UNet3+ Base model | unet3plus | |
| 2 | UNet3+ with Deep Supervision | unet3plus_deepsup | |
| 3 | UNet3+ with Deep Supervision and Classification Guided Module | unet3plus_deepsup_cgm | |
  • Note: unet3plus_deepsup_cgm can only be trained with the OUTPUT.CLASSES = 1 option.
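The classification-guided module (CGM) described in the UNet 3+ paper predicts whether the target organ is present in the image and uses that yes/no decision to suppress the segmentation output, which fits naturally with a single foreground class. The plain-Python sketch below illustrates only this gating step. `cgm_gate` is a hypothetical helper, not a function from this repo, and the [absent, present] ordering of the two classifier scores is an assumption.

```python
def cgm_gate(class_scores, seg_map):
    """Zero out a 2D segmentation map unless the classifier predicts 'present'.

    class_scores: two scores ordered as [absent, present] (assumed ordering).
    seg_map: 2D list of per-pixel foreground probabilities.
    """
    # argmax over the two classes gives 0 (absent) or 1 (present)
    present = max(range(len(class_scores)), key=lambda i: class_scores[i])
    # multiply every pixel by the 0/1 decision, as CGM does for each segmentation output
    return [[pixel * present for pixel in row] for row in seg_map]


seg = [[0.9, 0.2], [0.4, 0.7]]
print(cgm_gate([0.8, 0.2], seg))  # organ predicted absent: every pixel becomes 0
print(cgm_gate([0.3, 0.7], seg))  # organ predicted present: map is unchanged
```

Because the decision is a hard 0/1 multiplier, a confident "absent" prediction removes false-positive pixels entirely on images that contain no organ.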

Here is sample code for building each UNet 3+ variant:

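from models.unet3plus import UNet3Plus  # assumed module path; adjust to the actual file in models/

INPUT_SHAPE = [1, 320, 320]  # channels, height, width (single-channel 320x320 input)
OUTPUT_CHANNELS = 1          # number of segmentation output channels

# Base UNet 3+
unet_3P = UNet3Plus(INPUT_SHAPE, OUTPUT_CHANNELS, deep_supervision=False, CGM=False)
# UNet 3+ with deep supervision
unet_3P_deep_sup = UNet3Plus(INPUT_SHAPE, OUTPUT_CHANNELS, deep_supervision=True, CGM=False)
# UNet 3+ with deep supervision and the classification-guided module
unet_3P_deep_sup_cgm = UNet3Plus(INPUT_SHAPE, OUTPUT_CHANNELS, deep_supervision=True, CGM=True)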
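In the UNet 3+ paper, "full-scale" means every decoder stage aggregates feature maps from all five scales: shallower encoder maps are max-pooled down, deeper decoder maps (and the bottleneck) are bilinearly upsampled, and the same-scale encoder map is used as is. A 3×3 convolution then reduces each branch to 64 channels, and the five branches are concatenated into 5 × 64 = 320 channels. The plain-Python sketch below only computes those resize factors for a 320 × 320 input like INPUT_SHAPE above. The helper name, the `En`/`De` labels, and the 64-channel branch width follow the paper and are assumptions about this repo's code.

```python
CAT_CHANNELS = 64  # channels per branch in the paper's design (assumed for this repo)
DEPTH = 5          # encoder levels En1..En5; En5 is the bottleneck


def full_scale_sources(decoder_level, input_size=320, depth=DEPTH):
    """List (source, resize_op, factor, source_size) for decoder stage De<decoder_level>."""
    target = input_size // 2 ** (decoder_level - 1)  # spatial size of this decoder stage
    sources = []
    # Same-scale and shallower encoder stages: larger maps, max-pooled down to the target
    for j in range(1, decoder_level + 1):
        size = input_size // 2 ** (j - 1)
        factor = size // target
        op = "identity" if factor == 1 else "maxpool"
        sources.append((f"En{j}", op, factor, size))
    # Deeper decoder stages plus the bottleneck En5: smaller maps, upsampled to the target
    for k in range(decoder_level + 1, depth + 1):
        name = f"En{k}" if k == depth else f"De{k}"
        size = input_size // 2 ** (k - 1)
        sources.append((name, "upsample", target // size, size))
    return sources


# Walk the decoder from the deepest stage (De4) to the full-resolution stage (De1)
for level in range(DEPTH - 1, 0, -1):
    srcs = full_scale_sources(level)
    print(f"De{level}: concat {len(srcs)} x {CAT_CHANNELS} = {len(srcs) * CAT_CHANNELS} channels")
    for name, op, factor, size in srcs:
        print(f"  {name} ({size}x{size}) -> {op} x{factor}")
```

Every decoder stage sees exactly five inputs regardless of depth, which is why the concatenated width stays constant at 320 channels.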

You can find the UNet3+ hybrid loss here.
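The UNet 3+ paper defines its hybrid segmentation loss as the sum of a focal loss, an MS-SSIM loss, and an IoU loss. The plain-Python sketch below shows the focal and IoU terms on flattened per-pixel probabilities, plus the soft Dice coefficient that the losses package also provides. The MS-SSIM term is deliberately omitted, the `gamma=2.0` / `alpha=0.25` defaults are common focal-loss choices rather than values taken from this repo, and all function names are hypothetical.

```python
import math


def binary_focal_loss(probs, targets, gamma=2.0, alpha=0.25, eps=1e-7):
    """Mean binary focal loss over flattened per-pixel probabilities."""
    total = 0.0
    for p, t in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)         # clamp to avoid log(0)
        p_t = p if t == 1 else 1.0 - p          # probability given to the true class
        a_t = alpha if t == 1 else 1.0 - alpha  # class-balancing weight
        total += -a_t * (1.0 - p_t) ** gamma * math.log(p_t)
    return total / len(probs)


def soft_iou_loss(probs, targets, eps=1e-7):
    """1 - sum(T*P) / sum(T + P - T*P), i.e. IoU computed on soft predictions."""
    intersection = sum(p * t for p, t in zip(probs, targets))
    union = sum(probs) + sum(targets) - intersection
    return 1.0 - (intersection + eps) / (union + eps)


def dice_coef(probs, targets, eps=1e-7):
    """Soft Dice coefficient: 2 * |A and B| / (|A| + |B|)."""
    intersection = sum(p * t for p, t in zip(probs, targets))
    return (2.0 * intersection + eps) / (sum(probs) + sum(targets) + eps)


def partial_hybrid_loss(probs, targets):
    # Paper: focal + MS-SSIM + IoU; MS-SSIM is left out of this sketch
    return binary_focal_loss(probs, targets) + soft_iou_loss(probs, targets)


print(partial_hybrid_loss([0.9, 0.1], [1, 0]))  # good prediction: small loss
print(partial_hybrid_loss([0.1, 0.9], [1, 0]))  # inverted prediction: large loss
```

The focal term works per pixel and emphasizes hard pixels, while the IoU term scores the mask as a whole; the paper combines them (with MS-SSIM) so that pixel-, region-, and map-level errors are all penalized.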

Training

To train a model on the training dataset, call train.py with the required model type and configuration options.

For example, to train the base model, run:

python train.py MODEL.TYPE=unet3plus
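The MODEL.TYPE=unet3plus argument reads as a dotted-key override of a value in the YAML config. The stdlib-only sketch below shows that general idea on a plain nested dict. `apply_overrides` and the sample defaults are hypothetical; the repo may instead rely on a config library (for example Hydra or OmegaConf) with richer value parsing, so check train.py for the actual behavior.

```python
import copy


def apply_overrides(config, overrides):
    """Return a copy of `config` with 'A.B=value' strings applied as nested keys."""
    result = copy.deepcopy(config)  # leave the defaults untouched
    for item in overrides:
        dotted_key, sep, raw_value = item.partition("=")
        if not sep:
            raise ValueError(f"expected KEY=VALUE, got {item!r}")
        *parents, leaf = dotted_key.split(".")
        node = result
        for key in parents:
            node = node.setdefault(key, {})  # create intermediate sections if missing
        # integers become ints; everything else stays a string in this sketch
        node[leaf] = int(raw_value) if raw_value.lstrip("-").isdigit() else raw_value
    return result


# Hypothetical defaults standing in for the YAML file in configs/
defaults = {"MODEL": {"TYPE": "unet3plus"}, "OUTPUT": {"CLASSES": 1}}
cfg = apply_overrides(defaults, ["MODEL.TYPE=unet3plus_deepsup_cgm", "OUTPUT.CLASSES=1"])
print(cfg["MODEL"]["TYPE"], cfg["OUTPUT"]["CLASSES"])
```

In a real script the override list would come from `sys.argv[1:]`; copying the defaults first keeps the loaded YAML values available for comparison or logging.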

Validation

To validate a model on the validation dataset, call validate.py with the required model type and configuration options.

For example, to validate the base model and visualize the results, run:

python validate.py MODEL.TYPE=unet3plus

Inferencing

To run inference on the validation dataset, call predict.py with the required model type and configuration options.

For example, to run inference with the base model:

python predict.py MODEL.TYPE=unet3plus
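With OUTPUT_CHANNELS = 1, a binary segmentation output is commonly turned into a mask by applying a sigmoid and thresholding. The sketch below assumes the model returns raw logits and uses a 0.5 threshold; both are assumptions, and predict.py may already apply its own activation and threshold. `logits_to_mask` is a hypothetical helper.

```python
import math


def sigmoid(x):
    # numerically stable logistic function (avoids overflow in exp for large |x|)
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def logits_to_mask(logits, threshold=0.5):
    """Convert a 2D grid of logits into a 0/1 foreground mask."""
    return [[1 if sigmoid(v) >= threshold else 0 for v in row] for row in logits]


print(logits_to_mask([[2.0, -1.0], [0.3, -4.0]]))  # [[1, 0], [1, 0]]
```

Since the sigmoid is monotonic, thresholding probabilities at 0.5 is equivalent to thresholding logits at 0; the explicit sigmoid only matters if you want to visualize or tune the threshold on probabilities.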

Acknowledgement

We appreciate any feedback; bug reports and questions are welcome here.

Licensed under the MIT License.