
update phase prediction model
wasserth committed Jul 22, 2024
1 parent f8dac08 commit 3054b2b
Showing 7 changed files with 60 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,6 +1,7 @@
## Master
* add brain structures
* add liver vessels
* greatly improved phase classification model


## Release 2.3.0
26 changes: 15 additions & 11 deletions README.md
@@ -133,15 +133,6 @@ If you want to reduce memory consumption you can use the following options:
* `--nr_thr_saving 1`: Saving big images with several threads will take a lot of memory



### Python API
You can run totalsegmentator via Python:
```python
@@ -159,13 +150,12 @@ if __name__ == "__main__":
```
You can see all available arguments [here](https://github.com/wasserth/TotalSegmentator/blob/master/totalsegmentator/python_api.py). Running from within the main environment should avoid some multiprocessing issues.

The segmentation image contains the names of the classes in the extended header. If you want to load this additional header information you can use the following code (requires `pip install xmltodict`):
```python
from totalsegmentator.nifti_ext_header import load_multilabel_nifti

segmentation_nifti_img, label_map_dict = load_multilabel_nifti(image_path)
```


### Install latest master branch (contains latest bug fixes)
@@ -175,6 +165,11 @@ pip install git+https://github.com/wasserth/TotalSegmentator.git


### Other commands
If you want to know which contrast phase a CT image is in, you can use the following command (requires `pip install xgboost`). More details can be found [here](resources/contrast_phase_prediction.md):
```
totalseg_get_phase -i ct.nii.gz -o contrast_phase.json
```

If you want to combine some subclasses (e.g. lung lobes) into one binary mask (e.g. entire lung) you can use the following command:
```
totalseg_combine_masks -i totalsegmentator_output_dir -o combined_mask.nii.gz -m lung
@@ -191,6 +186,15 @@ totalseg_set_license -l aca_12345678910
```


### Train/validation/test split
The exact split of the dataset can be found in the file `meta.csv` inside of the [dataset](https://doi.org/10.5281/zenodo.6802613). This was used for the validation in our paper.
The exact numbers of the results for the high-resolution model (1.5mm) can be found [here](resources/results_all_classes_v1.json). The paper shows these numbers in the supplementary materials Figure 11.


### Retrain model and run evaluation
See [here](resources/train_nnunet.md) for more info on how to train a nnU-Net yourself on the TotalSegmentator dataset, how to split the data into train/validation/test set as in our paper, and how to run the same evaluation as in our paper.


### Typical problems

**ITK loading Error**
Binary file removed resources/contrast_phase_classifiers.pkl
Binary file not shown.
Binary file added resources/contrast_phase_classifiers_2024_07_19.pkl
Binary file not shown.
30 changes: 30 additions & 0 deletions resources/contrast_phase_prediction.md
@@ -0,0 +1,30 @@
# Details on how the prediction of the contrast phase is done

TotalSegmentator is used to predict the following structures:
```python
["liver", "pancreas", "urinary_bladder", "gallbladder",
"heart", "aorta", "inferior_vena_cava", "portal_vein_and_splenic_vein",
"iliac_vena_left", "iliac_vena_right", "iliac_artery_left", "iliac_artery_right",
"pulmonary_vein", "brain", "colon", "small_bowel",
"internal_carotid_artery_right", "internal_carotid_artery_left",
"internal_jugular_vein_right", "internal_jugular_vein_left"]
```
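The listed structures are each reduced to a single feature, their median HU value. The step can be sketched as follows; this is an illustrative helper, not the project's actual code, and the `label_map` input and the `-1.0` placeholder for absent structures are assumptions:

```python
import numpy as np

def median_hu_features(ct, seg, label_map, structures):
    """Sketch: one feature per structure, its median HU value.

    ct: 3D array of HU values; seg: 3D integer label image;
    label_map: structure name -> label value (hypothetical interface).
    """
    feats = []
    for name in structures:
        mask = seg == label_map[name]
        # Placeholder when the structure is absent from the image
        # (an assumption, not necessarily what TotalSegmentator does)
        feats.append(float(np.median(ct[mask])) if mask.any() else -1.0)
    return feats
```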
Then the median intensity (HU value) of each structure is used as a feature for an XGBoost classifier
that predicts the post-injection time (pi_time). The pi_time can then be mapped to the contrast phase.
It classifies images into the `native`, `arterial_early`, `arterial_late`, and `portal_venous` phases.
The classifier was trained on the TotalSegmentator dataset and therefore works across a wide range
of different CT images.
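The mapping from a predicted pi_time to a phase label might look like the sketch below. The cut-off values are purely hypothetical placeholders; the real classifier learns this relationship from the dataset:

```python
def pi_time_to_phase(pi_time: float) -> str:
    """Map a predicted post-injection time (seconds) to a phase label.

    The thresholds below are hypothetical, for illustration only.
    """
    if pi_time < 5:
        return "native"
    elif pi_time < 20:
        return "arterial_early"
    elif pi_time < 40:
        return "arterial_late"
    else:
        return "portal_venous"
```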

Results on 5-fold cross validation:

- Mean absolute error (MAE): 5.55s
- F1 scores for each class:
- native: 0.980
- arterial_early+late: 0.915
- portal: 0.940

The results contain a probability for each class, which is high if the predicted pi_time is close to the ideal
pi_time for the given phase. Moreover, the classifier is an ensemble of 5 models. The output contains the
standard deviation of the predictions, which can be used as a measure of confidence: if it is low, the 5 models
give similar predictions, which is a good sign.
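The ensemble idea can be sketched like this; a minimal illustration in which `models` is any list of callables mapping a feature vector to a pi_time (a hypothetical interface, not the pickled XGBoost objects themselves):

```python
import statistics

def ensemble_predict(models, features):
    """Average the ensemble's pi_time predictions and report their spread.

    A low standard deviation means the models agree, which serves as a
    simple confidence measure.
    """
    preds = [m(features) for m in models]
    return statistics.mean(preds), statistics.stdev(preds)
```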

5 changes: 4 additions & 1 deletion setup.py
@@ -11,7 +11,10 @@
python_requires='>=3.9',
license='Apache 2.0',
packages=find_packages(),
package_data={"totalsegmentator":
["resources/totalsegmentator_snomed_mapping.csv",
"resources/contrast_phase_classifiers_2024_07_19.pkl"]
},
install_requires=[
'torch>=2.0.0',
'numpy<2',
18 changes: 10 additions & 8 deletions totalsegmentator/bin/totalseg_get_phase.py
@@ -68,7 +68,7 @@ def get_ct_contrast_phase(ct_img: nib.Nifti1Image, model_file: Path = None):
# print(f"ts took: {time.time()-st:.2f}s")

if stats["brain"]["volume"] > 100:
# print("Brain in image, therefore also running headneck model.")
st = time.time()
seg_img_hn, stats_hn = totalsegmentator(ct_img, None, ml=True, fast=False, statistics=True,
task="headneck_bones_vessels",
@@ -85,11 +85,9 @@ def get_ct_contrast_phase(ct_img: nib.Nifti1Image, model_file: Path = None):
features.append(stats_hn[organ]["intensity"])

if model_file is None:
classifier_path = Path(__file__).parents[2] / "resources" / "contrast_phase_classifiers_2024_07_19.pkl"
else:
# manually set model file
classifier_path = model_file
clfs = pickle.load(open(classifier_path, "rb"))

@@ -136,13 +134,17 @@ def main():
parser.add_argument("-m", metavar="filepath", dest="model_file",
help="path to classifier model",
type=lambda p: Path(p).absolute(), required=False, default=None)

parser.add_argument("-q", dest="quiet", action="store_true",
help="Print no output to stdout", default=False)

args = parser.parse_args()

res = get_ct_contrast_phase(nib.load(args.input_file), args.model_file)

if not args.quiet:
print("Result:")
pprint(res)

with open(args.output_file, "w") as f:
f.write(json.dumps(res, indent=4))
