Refactor code and add example notebooks #19

Merged
merged 1 commit into from
Nov 25, 2023
6 changes: 3 additions & 3 deletions .github/workflows/unit-test.yml
@@ -32,10 +32,10 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install tensorflow-cpu==2.12.0
pip install jax==0.4.14
pip install jaxlib==0.4.14
pip install jax==0.4.20
pip install jaxlib==0.4.20
pip install -r docker/requirements.txt
pip install -e .
- name: Test with pytest
run: |
pytest --splits 4 --group ${{ matrix.group }} --randomly-seed=0 -k "not slow"
pytest --splits 4 --group ${{ matrix.group }} --randomly-seed=0 -k "not slow and not integration"
3 changes: 2 additions & 1 deletion .gitignore
@@ -41,7 +41,7 @@ pip-delete-this-directory.txt
htmlcov/
.tox/
.nox/
.coverage
.coverage*
.coverage.*
.cache
nosetests.xml
@@ -137,6 +137,7 @@ dmypy.json
# notebook
*.ipynb
notebooks/
!examples/segmentation/inference.ipynb

# hydra outputs
outputs/
32 changes: 7 additions & 25 deletions .pre-commit-config.yaml
@@ -5,6 +5,7 @@ repos:
rev: v4.5.0
hooks:
- id: check-added-large-files
args: ["--maxkb=15000"]
- id: check-ast
- id: check-byte-order-marker
- id: check-builtin-literals
@@ -27,43 +28,24 @@ repos:
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 23.10.0
rev: 23.11.0
hooks:
- id: black
args:
- --line-length=100
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.6.1
rev: v1.7.1
hooks: # https://github.com/python/mypy/issues/4008#issuecomment-582458665
- id: mypy
name: mypy-imgx
files: ^imgx/
entry: mypy imgx/
pass_filenames: false
args:
[
--strict-equality,
--disallow-untyped-calls,
--disallow-untyped-defs,
--disallow-incomplete-defs,
--check-untyped-defs,
--disallow-untyped-decorators,
--warn-redundant-casts,
--warn-unused-ignores,
--no-warn-no-return,
--warn-unreachable,
]
- id: mypy
name: mypy-imgx_datasets
files: ^imgx_datasets/
entry: mypy imgx_datasets/
name: mypy
pass_filenames: false
args:
[
--strict-equality,
--disallow-untyped-calls,
--disallow-untyped-defs,
--disallow-incomplete-defs,
--disallow-any-generics,
--check-untyped-defs,
--disallow-untyped-decorators,
--warn-redundant-casts,
@@ -72,15 +54,15 @@ repos:
--warn-unreachable,
]
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.3
rev: v3.1.0
hooks:
- id: prettier
args:
- --print-width=100
- --prose-wrap=always
- --tab-width=2
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.1.1"
rev: "v0.1.6"
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-pylint
48 changes: 34 additions & 14 deletions README.md
@@ -1,16 +1,28 @@
# ImgX-DiffSeg

ImgX-DiffSeg is a Jax-based deep learning toolkit (now using Flax) for biomedical image
segmentation.
ImgX-DiffSeg is a JAX-based deep learning toolkit using Flax for biomedical image segmentation.

This repository currently includes the implementation of the following work
This repository includes the implementation of the following work:

- [A Recycling Training Strategy for Medical Image Segmentation with Diffusion Denoising Models](https://arxiv.org/abs/2308.16355)
- [Importance of Aligning Training Strategy with Evaluation for Diffusion Models in 3D Multiclass Segmentation](https://arxiv.org/abs/2303.06040)

:construction: **The codebase is still under active development for more enhancements and
applications.** :construction:

- November 2023:
- :warning: Upgraded JAX to 0.4.20.
- :warning: Removed Haiku-specific modification to convolutional layers. This may impact model
performance.
- :smiley: Added example notebooks for inference on a single image without TFDS.
- Added integration tests for training, validation and testing.
- Refactored config.
- Added `patch_size` and `scale_factor` to data config.
- Moved loss config from main config to task config.
- Refactored code, including defining `imgx/task` submodule.
- October 2023: :sunglasses: Migrated from [Haiku](https://github.com/google-deepmind/dm-haiku) to
[Flax](https://github.com/google/flax) following Google DeepMind's recommendation.

:mailbox: Please feel free to
[create an issue](https://github.com/mathpluscode/ImgX-DiffSeg/issues/new/choose) to request
features or [reach out](https://orcid.org/0000-0002-1184-7421) for collaborations. :mailbox:
@@ -61,11 +73,6 @@ See the [readme](imgx_datasets/README.md) for further details.
- Gradient clipping and accumulation.
- [Early stopping](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html).

**Changelog**

- October 2023: Migrated from [Haiku](https://github.com/google-deepmind/dm-haiku) to
[Flax](https://github.com/google/flax) following Google DeepMind's recommendation.

## Installation

### TPU with Docker
@@ -112,8 +119,7 @@ The following instructions have been tested only for TPU-v3-8. The docker contai

### GPU with Docker

The following instructions have been tested only for CUDA == 11.4.1 and CUDNN == 8.2.0. The docker
container uses non-root user.
CUDA >= 11.8 is required. The docker container uses a non-root user.
[Docker image used may be removed.](https://gitlab.com/nvidia/container-images/cuda/blob/master/doc/support-policy.md)

1. Build the docker image inside the repository.
@@ -141,7 +147,7 @@ container uses non-root user.
where

- `--rm` removes the container once it exits.
- `-v` maps the `ImgX` folder into container.
- `-v` maps the current folder into the container.

3. Install the package inside the container.

@@ -214,12 +220,10 @@ export DATASET_NAME="brats2021_mr"

# Vanilla segmentation
imgx_train data=${DATASET_NAME} task=seg
imgx_valid --log_dir wandb/latest-run/
imgx_test --log_dir wandb/latest-run/

# Diffusion-based segmentation
imgx_train data=${DATASET_NAME} task=gaussian_diff_seg
imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDPM
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDPM
imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
@@ -259,10 +263,26 @@ Run the command below to test and get coverage report. As JAX tests requires two
threads, therefore requires 8 CPUs in total.

```bash
pytest --cov=imgx -n 4 imgx
pytest --cov=imgx -n 4 imgx -k "not integration"
pytest --cov=imgx_datasets -n 4 imgx_datasets
```

`-k "not integration"` excludes integration tests, which require downloading the muscle
ultrasound and AMOS CT data sets.

For integration tests, run the command below. `-s` enables printing of stdout. This test may take
40-60 minutes.

```bash
pytest imgx/integration_test.py -s
```

To test the Jupyter notebooks, run the command below.

```bash
pytest --nbmake examples/**/*.ipynb
```
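
The `**` in the notebook command is expanded by the shell. In bash it only recurses into nested
folders when `globstar` is enabled; otherwise it behaves like a single `*`. As a hedged
illustration (not part of this PR), the sketch below lists the notebooks under `examples/`
recursively from Python. `find_notebooks` is a hypothetical helper name, and skipping
`.ipynb_checkpoints` folders is an assumption about which files one would want to test.

```python
from pathlib import Path


def find_notebooks(root: str = "examples") -> list[str]:
    """Return sorted notebook paths under `root`, searched recursively.

    Autosave copies inside `.ipynb_checkpoints` folders are skipped
    (an assumption for this sketch, not something the PR specifies).
    """
    return sorted(
        str(path)
        for path in Path(root).rglob("*.ipynb")
        if ".ipynb_checkpoints" not in path.parts
    )


if __name__ == "__main__":
    # Print one notebook path per line, e.g. to pass on to `pytest --nbmake`.
    for notebook in find_notebooks():
        print(notebook)
```

The printed paths can be passed explicitly to `pytest --nbmake`, which avoids relying on how a
particular shell expands `**`.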

## References

- [Segment Anything (PyTorch)](https://github.com/facebookresearch/segment-anything)
6 changes: 3 additions & 3 deletions docker/Dockerfile
@@ -75,10 +75,10 @@ COPY docker/requirements.txt /${USER}/requirements.txt

RUN /${USER}/conda/bin/pip3 install --upgrade pip \
&& /${USER}/conda/bin/pip3 install \
jax==0.4.14 \
jaxlib==0.4.14+cuda11.cudnn86 \
jax==0.4.20 \
jaxlib==0.4.20+cuda11.cudnn86 \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
&& /${USER}/conda/bin/pip3 install tensorflow-cpu==2.12.0 \
&& /${USER}/conda/bin/pip3 install tensorflow-cpu==2.14.0 \
&& /${USER}/conda/bin/pip3 install -r /${USER}/requirements.txt

RUN git config --global --add safe.directory /${USER}/ImgX
2 changes: 1 addition & 1 deletion docker/Dockerfile.tpu
@@ -1,4 +1,4 @@
FROM mambaorg/micromamba:0.27.0 as conda
FROM mambaorg/micromamba:1.5.1 as conda

# Speed up the build, and avoid unnecessary writes to disk
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 PYTHONDONTWRITEBYTECODE=1 PYTHONUNBUFFERED=1
8 changes: 4 additions & 4 deletions docker/environment.yml
@@ -3,9 +3,9 @@ channels:
- defaults
dependencies:
- python=3.9
- pip=23.0.1
- pip=23.3.1
- pip:
- tensorflow-cpu==2.13.0
- jax==0.4.14
- jaxlib==0.4.14
- tensorflow-cpu==2.14.0
- jax==0.4.20
- jaxlib==0.4.20
- -r requirements.txt
10 changes: 5 additions & 5 deletions docker/environment_mac_m1.yml
@@ -3,10 +3,10 @@ channels:
- defaults
dependencies:
- python=3.9
- pip=23.0.1
- pip=23.3.1
- pip:
- tensorflow-macos==2.13.0
- tensorflow-metal==1.0.1
- jax==0.4.14
- jaxlib==0.4.14
- tensorflow-macos==2.14.0
- tensorflow-metal==1.1.0
- jax==0.4.20
- jaxlib==0.4.20
- -r requirements.txt
8 changes: 4 additions & 4 deletions docker/environment_tpu.yml
@@ -4,10 +4,10 @@ channels:
- conda-forge
dependencies:
- python=3.9
- pip=23.0.1
- pip=23.3.1
- pip:
- --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
- tensorflow-cpu==2.13.0
- jax[tpu]==0.4.14
- jaxlib==0.4.14
- tensorflow-cpu==2.14.0
- jax[tpu]==0.4.20
- jaxlib==0.4.20
- -r requirements.txt
30 changes: 16 additions & 14 deletions docker/requirements.txt
@@ -1,24 +1,26 @@
SimpleITK==2.3.0
SimpleITK==2.3.1
chex==0.1.8
coverage==7.3.1
flax==0.7.4
coverage==7.3.2
flax==0.7.5
hydra-core==1.3.2
kaggle==1.5.16
numpy==1.24.3 # limited by tensorflow-macos 2.13.0
opencv-python==4.8.0.76
nbmake==1.4.6
numpy==1.26.2
opencv-python==4.8.1.78
optax==0.1.7
pandas==2.1.1
pre-commit==3.4.0
pandas==2.1.3
pre-commit==3.5.0
protobuf==3.20.3 # https://github.com/tensorflow/datasets/issues/4858
pytest-cov==4.1.0
pytest-mock==3.12.0
pytest-randomly==3.15.0
pytest-split==0.8.1
pytest-xdist==3.3.1
pytest==7.4.2
pytest-xdist==3.5.0
pytest==7.4.3
rdkit-pypi==2022.9.5
rich==13.5.3
ruff==0.0.291
rich==13.7.0
ruff==0.1.6
tensorflow-datasets==4.9.3
torch==2.0.1 # for testing only
wandb==0.15.11
wily==1.24.2
torch==2.1.1 # for testing only
wandb==0.16.0
wily==1.25.0
Binary file added examples/segmentation/BB_anon_348_1.png
Binary file added examples/segmentation/BB_anon_348_1_mask.png
83 changes: 83 additions & 0 deletions examples/segmentation/config.yaml
@@ -0,0 +1,83 @@
data:
name: muscle_us
loader:
max_num_samples_per_split: -1
patch_shape:
- 480
- 512
patch_overlap:
- 0
- 0
data_augmentation:
max_rotation:
- 0.088
max_translation:
- 10
- 10
max_scaling:
- 0.15
- 0.15
trainer:
max_num_samples: 512000
batch_size: 64
batch_size_per_replica: 8
num_devices_per_replica: 1
patch_size:
- 2
- 2
scale_factor:
- 2
- 2
task:
name: segmentation
model:
_target_: imgx.model.Unet
remat: true
num_spatial_dims: 2
patch_size:
- 2
- 2
scale_factor:
- 2
- 2
num_channels:
- 8
- 16
- 32
- 64
out_channels: 2
num_heads: 8
widening_factor: 4
num_transform_layers: 1
loss:
dice: 1.0
cross_entropy: 0.0
focal: 20.0
early_stopping:
metric: mean_binary_dice_score_without_background
mode: max
min_delta: 0.0001
patience: 10
debug: false
seed: 0
half_precision: true
optimizer:
name: adamw
kwargs:
b1: 0.9
b2: 0.999
weight_decay: 1.0e-08
grad_norm: 1.0
lr_schedule:
warmup_steps: 100
decay_steps: 10000
init_value: 1.0e-05
peak_value: 0.0008
end_value: 5.0e-05
logging:
root_dir: null
log_freq: 10
save_freq: 100
wandb:
project: imgx
entity: entity
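
The `lr_schedule` block in this config names a warmup phase and a decay phase. The sketch below
shows one plausible reading of those five fields: a linear warmup from `init_value` to
`peak_value` over `warmup_steps`, followed by a cosine decay to `end_value` that finishes at step
`decay_steps`. This is an assumption; the diff does not show how the repository consumes these
values. The field names do coincide with the parameters of optax's
`warmup_cosine_decay_schedule`, which is why that shape is used here. `warmup_cosine_lr` is a
hypothetical name, written in plain Python so it runs without JAX.

```python
import math


def warmup_cosine_lr(
    step: int,
    init_value: float = 1.0e-5,
    peak_value: float = 8.0e-4,
    end_value: float = 5.0e-5,
    warmup_steps: int = 100,
    decay_steps: int = 10_000,
) -> float:
    """Learning rate at `step` under an ASSUMED warmup + cosine-decay reading."""
    if step < warmup_steps:
        # Linear ramp from init_value towards peak_value.
        return init_value + (peak_value - init_value) * step / warmup_steps
    # Cosine decay from peak_value to end_value. decay_steps is treated here as
    # the total schedule length including warmup (an assumption of this sketch).
    progress = min((step - warmup_steps) / (decay_steps - warmup_steps), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (peak_value - end_value) * cosine


if __name__ == "__main__":
    # Print the learning rate at a few representative steps.
    for step in (0, 50, 100, 5_050, 10_000):
        print(step, warmup_cosine_lr(step))
```

Under this reading the rate starts at `init_value`, reaches `peak_value` when warmup ends, and
stays at `end_value` for any step at or beyond `decay_steps`.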
Binary file not shown.