Training SynthSeg on spine-generic data #111

Open
naga-karthik opened this issue Jun 20, 2024 · 8 comments

@naga-karthik
Collaborator

naga-karthik commented Jun 20, 2024

This issue summarizes my attempt at training the SynthSeg on the spine-generic multi subject dataset.

Brief context: SynthSeg was originally proposed for the segmentation of brain scans of any resolution and contrast. Because it is "contrast-agnostic", I am trying to re-train the model as a baseline for our contrast-agnostic spinal cord (SC) segmentation model. An important note here is that SynthSeg requires fully-labeled brain scans, from which it synthesizes fake brain images (sampled from a GMM) and then trains the segmentation model. A notable challenge in re-training SynthSeg for SC images is that it would require labels for all the structures in a SC scan (e.g. SC, cerebrospinal fluid, vertebrae, bones, brain).

As it is not feasible to obtain labels for each anatomical region in a SC image, I tried to ease the constraints by only focusing on the segmentations of 4 parts: (i) vertebrae, (ii) discs, (iii) SC, and (iv) cerebrospinal fluid (CSF). The labels for these regions were obtained using the TotalSegmentatorMRI model.

Experiments

The SynthSeg repo contains well-described tutorials for re-training your own models. The following sections will describe how I have tweaked SynthSeg for SC data.

Defining the labels for generating synthetic images

There are 4 key elements in generating synthetic images based on labels:

  1. Defining the path to the training label maps: In my case, this corresponds to the folder containing the TotalSeg labels. I only used the labels for T1w and T2w contrasts as the labels were not available for the rest of the spine-generic contrasts.
  2. Label classes for generation: In my case, these correspond to the ids for the background, vertebrae, discs, SC and CSF. Following the preliminary version of the TotalSpineSeg model, I used the following ids to define the labels. Note that using the same ids as in the tutorial didn't make sense, as those labels correspond to the brain (and not the SC).
gen_labels_cerv_totalseg = np.array([0, 
    # vertebrae (covering C1 to T12)
    41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23,
    # discs (covering C1/C2 to T11/T12)
    224, 223, 222, 221, 220, 219, 218, 217, 216, 215, 214, 213, 212, 211, 210, 209, 208,
    # SC and CSF
    200, 201,
])
  3. Defining the output label maps: To keep it simple, I used the same labels as above, meaning that the synthetic images will have all the classes that are defined in the label generation array above. (SynthSeg provides an option to disable some labels if you don't want to generate them.) Also, to keep it simple, all the vertebrae correspond to one class and all the discs correspond to one class (class ids shown below).
  4. Defining the generation classes: This corresponds to the output classes that the segmentation model learns to output during training. Unlike the arbitrary ids shown above, these are typically incremental ids running from 0 to the total number of classes. In my case, I provided the following 4 classes (with 0 being the background):
1: SC
2: CSF
3: Vertebrae
4: Discs

These are essentially the 4 key parameters in the procedure for generating synthetic images from the GMM.
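
The mapping from the generation ids to the 4 output classes can be sketched in plain NumPy. This is only an illustration, not the code used for training: the helper names `to_segmentation_classes` and `remap_label_map` are hypothetical, and the assignment 200 = SC, 201 = CSF is assumed from the order in which the two ids are listed.

```python
import numpy as np

# Generation ids as listed above (built with arange to avoid typos).
vertebrae_ids = np.arange(41, 22, -1)    # 41 ... 23
disc_ids = np.arange(224, 207, -1)       # 224 ... 208
sc_id, csf_id = 200, 201                 # assumption: 200 = SC, 201 = CSF
gen_labels = np.concatenate([[0], vertebrae_ids, disc_ids, [sc_id, csf_id]])


def to_segmentation_classes(labels):
    """Return, for each generation id, its output class (0 = background)."""
    out = np.zeros_like(labels)
    out[labels == sc_id] = 1
    out[labels == csf_id] = 2
    out[np.isin(labels, vertebrae_ids)] = 3   # all vertebrae share one class
    out[np.isin(labels, disc_ids)] = 4        # all discs share one class
    return out


seg_classes = to_segmentation_classes(gen_labels)


def remap_label_map(label_map):
    """Convert a label map holding generation ids into output classes."""
    lut = np.zeros(gen_labels.max() + 1, dtype=seg_classes.dtype)
    lut[gen_labels] = seg_classes
    return lut[label_map]
```

`seg_classes` is aligned element-by-element with `gen_labels`, which is the layout a per-label lookup needs; `remap_label_map` applies the same lookup to a whole volume.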

@naga-karthik
Collaborator Author

naga-karthik commented Jun 20, 2024

The following are a few examples of synthetic images from the GMM

Example 1

ezgif com-animated-gif-maker

Example 2

ezgif com-animated-gif-maker-2

Example 3

ezgif com-animated-gif-maker-3

As expected, the synthetic images look very weird. That is the point of SynthSeg: the heavy deformations and unrealistic intensities are what make it contrast-agnostic.
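
To make the "sampled from a GMM" idea concrete, here is a minimal sketch of the intensity-sampling step only. It is not SynthSeg's implementation: the function name `sample_gmm_image` and the mean/std ranges are illustrative assumptions, and the real generator additionally applies spatial deformation, a bias field and resolution simulation.

```python
import numpy as np


def sample_gmm_image(label_map, rng, mean_range=(0.0, 255.0), std_range=(0.0, 35.0)):
    """Draw one random Gaussian per label and sample every voxel from it.

    The ranges are illustrative placeholders, not SynthSeg's priors.
    """
    labels = np.unique(label_map)
    means = rng.uniform(*mean_range, size=labels.size)   # one random mean per label
    stds = rng.uniform(*std_range, size=labels.size)     # one random std per label
    idx = np.searchsorted(labels, label_map)             # label value -> index in `labels`
    image = rng.normal(means[idx], stds[idx])            # voxel-wise Gaussian sample
    return np.clip(image, 0.0, 255.0)
```

Because the means are redrawn for every call, the same label map yields a differently "contrasted" image each time, which is why the samples do not resemble any real MRI contrast.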

@naga-karthik
Collaborator Author

Training

The training procedure borrows several aspects of the synthetic image generation procedure (described in the 1st comment) and adds the regular training hyperparameters. To keep it close to our contrast-agnostic model for SC, I set the output shape to 64 x 192 x 256. I also set bias_field_std=0.5 as this is the range we chose in our transforms.

When training for 5 epochs initially, the model encountered NaN values in the loss. It seems that NaN loss values were also encountered here when re-training SynthSeg. More debugging to be continued ...
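
One cheap sanity check while debugging this kind of setup is to confirm that every id present in a training label map is also listed in the generation labels. The sketch below is a generic check under that assumption; it is not claimed to be the cause of the NaN values here, and `find_unlisted_labels` is a hypothetical helper name.

```python
import numpy as np


def find_unlisted_labels(label_map, gen_labels):
    """Return the ids found in `label_map` that are missing from `gen_labels`."""
    return np.setdiff1d(np.unique(label_map), np.asarray(gen_labels))


# Toy example: id 99 appears in the label map but is not a generation label.
toy_map = np.array([[0, 200], [201, 99]])
missing = find_unlisted_labels(toy_map, [0, 200, 201])
print("unlisted ids:", missing.tolist())
```

An empty result means the label map and the generation labels are at least consistent with each other.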

@naga-karthik
Collaborator Author

naga-karthik commented Jul 3, 2024

I was able to fix the NaN issue and the first set of results is in. Each of the following comments describes an experiment along with the results obtained. The generation labels and the segmentation classes remain the same as described in #111 (comment). A total of 209 labeled cervical spine scans were used, with labels for the cord, CSF, vertebrae and intervertebral discs obtained from T1w and T2w contrasts. (Note that SynthSeg primarily used fully-labeled brain scans of the T1w contrast.)

Experiment 1

The following hyperparameters are defined in this training script, which I have modified to stay close to our contrast-agnostic model for SC segmentation. The hyperparams defined below are the only ones changed for this experiment; all the others are left at their default values.

n_neutral_labels = None   # setting this to an incorrect value is what I believe caused the NaN values above
# shape and resolution of the outputs
target_res = None   # SynthSeg by default outputs everything at 1mm iso; since we don't want that, it is set to None
output_shape = (64, 192, 256)    # setting this to the original shape of the labels (192, 260, 320) resulted in OOM errors;
                                 # the cropped output shape is also close to what we use in contrast-agnostic (64, 192, 320) i.e. heavy cropping on R-L dim.

# spatial deformation parameters
flipping = False
rotation_bounds = 20
shearing_bounds = .012
bias_field_std = .5

# architecture parameters
activation = 'relu'     # we used relu, so changed it from 'elu' to 'relu'

# training parameters
dice_epochs = 25       # number of training epochs
steps_per_epoch = 1000  # number of iterations per epoch
batchsize = 2
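
To give a feel for why the full label shape runs out of memory while the cropped shape does not, the sketch below compares the voxel counts of the two shapes quoted above and shows a simple center crop. `center_crop` is a hypothetical helper written for illustration; it is not how SynthSeg crops during generation.

```python
import numpy as np


def center_crop(volume, target_shape):
    """Crop `volume` around its center to `target_shape` (illustrative helper)."""
    assert all(t <= s for s, t in zip(volume.shape, target_shape)), "crop larger than volume"
    starts = [(s - t) // 2 for s, t in zip(volume.shape, target_shape)]
    slices = tuple(slice(st, st + t) for st, t in zip(starts, target_shape))
    return volume[slices]


full_shape = (192, 260, 320)   # original shape of the labels
crop_shape = (64, 192, 256)    # output_shape used for training
voxel_ratio = np.prod(full_shape) / np.prod(crop_shape)
print(f"the full volume has {voxel_ratio:.1f}x more voxels than the crop")
```

Activation memory in a 3D network grows with the number of voxels per sample, so a roughly fivefold reduction in voxels makes a large difference at batchsize = 2.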

Results

Test sample 1

sub-barcelona05_T1w

Screenshot 2024-07-03 at 9 51 38 AM

Test sample 2

sub-amu05_T1w

Screenshot 2024-07-03 at 9 56 18 AM

We see that the predictions are poor: they fail to capture the structure of the cervical spine, and none of the labels (cord, CSF, vertebrae, discs) are properly segmented.

@naga-karthik
Collaborator Author

Experiment 2

Since flipping=True is the default in SynthSeg, I trained another model, this time setting flipping to True and keeping the rest of the hyperparams defined in #111 (comment) the same.

Test sample 1

sub-barcelona05_T1w

Screenshot 2024-07-03 at 10 08 41 AM

Surprisingly, the model does not predict anything when flipping=True is used during training.

@naga-karthik
Collaborator Author

naga-karthik commented Jul 4, 2024

Experiment 3

To simplify the segmentation problem, I tried to re-train the model with only 2 output classes (i.e. the spinal cord and the CSF). Specifically, given all the following labels for generation, the model is trained to output only the SC and CSF classes, with label values 200 and 201 (keeping the rest of the hyperparams the same as in Experiment 1).

gen_labels_cerv_totalseg = np.array([0, 
    # vertebrae (covering C1 to T12)
    41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23,
    # discs (covering C1/C2 to T11/T12)
    224, 223, 222, 221, 220, 219, 218, 217, 216, 215, 214, 213, 212, 211, 210, 209, 208,
    # SC and CSF --> only output these labels
    200, 201,
])

The model failed to output anything here as well: we get empty predictions for all test images ...

Test sample 1

sub_brnoUhb04_T1w

Screenshot 2024-07-04 at 11 55 28 AM

Given the string of poor results, I am not sure what the actual reason is for SynthSeg failing on SC images. One potential issue could be that, at test time, the input T1w images contain both the brain and the SC (like any typical T1w SC image), whereas the synthetic scans that SynthSeg generates during training contain only the 4 labeled classes. Maybe this discrepancy is one issue? But, to fix this, we would need all the labels (i.e. those of the brain and those of the SC), which is really impractical.

@jcohenadad
Member

But, to fix this, it would mean that we need all the labels (i.e. for those of the brain and those of the SC, which is really impractical)

indeed-- let's not go there

@naga-karthik
Collaborator Author

naga-karthik commented Jan 2, 2025

Following an excellent clarification from the reviewer, I realized that I had indeed misunderstood how SynthSeg generates the synthetic scans. Since the beginning, we have been assuming that SynthSeg requires labels for all the structures (i.e. the soft tissues around the cord, the brain, and everything else that is not the spinal cord). But that isn't the case. SynthSeg can work even without GT labels for the "non-interesting structures" (in our case, everything outside the cord segmentation). In the paper, for the cardiac segmentation example, they generated the labels outside the cardiac part by clustering the image based on intensities using an expectation-maximization (EM) algorithm. EM-based clustering is not a hard requirement, though: the labels for non-interesting structures can also be created using the (much simpler) KMeans algorithm. What follows is a description of my attempt to create KMeans-based labels for our SC scans.

Following the approach mentioned in Supplement 7 (Figure S3 caption), I kept the GT cord segmentation fixed and generated labels for additional (non-interesting) structures using 3-10 clusters. The script for this can be found here. Inputs are T1w scans and their corresponding GT cord masks. Outputs are multi-class labels where 0 is the background, 1 is the cord and 2-11 are the labels obtained with KMeans clustering. Some examples of the synthetic scans are below.
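
The labeling scheme described above (0 = background, 1 = cord, 2 onwards = intensity clusters) can be sketched as follows. This is not the linked script: it uses a small 1D k-means written in NumPy instead of a library implementation, the helper names are hypothetical, and treating voxels at or below `bg_threshold` as background is an assumption made for the sketch.

```python
import numpy as np


def kmeans_1d(values, n_clusters, n_iters=20):
    """Cluster scalar intensities with Lloyd's algorithm; returns ids in 0..n_clusters-1."""
    # Deterministic init: evenly spaced quantiles of the intensities.
    centers = np.quantile(values, np.linspace(0, 1, n_clusters + 2)[1:-1])
    for _ in range(n_iters):
        centers = np.sort(centers)
        midpoints = (centers[:-1] + centers[1:]) / 2
        assign = np.searchsorted(midpoints, values)   # nearest center in 1D
        for k in range(n_clusters):
            if np.any(assign == k):                   # keep old center if cluster is empty
                centers[k] = values[assign == k].mean()
    centers = np.sort(centers)
    return np.searchsorted((centers[:-1] + centers[1:]) / 2, values)


def build_label_map(image, cord_mask, n_clusters, bg_threshold=0.0):
    """0 = background, 1 = cord (kept fixed), 2.. = clusters of the remaining voxels."""
    labels = np.zeros(image.shape, dtype=np.int32)
    cord = cord_mask > 0
    other = (image > bg_threshold) & ~cord            # assumption: low intensity = background
    labels[other] = kmeans_1d(image[other].astype(float), n_clusters) + 2
    labels[cord] = 1                                  # GT cord always overrides the clusters
    return labels
```

Writing the cord label last guarantees that the GT cord mask is never altered by the clustering, which matches the "kept the GT cord segmentation fixed" constraint.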

Example 1

Image

Example 2

Image

Example 3

Image

The synthetic images look similar to the ones shown in Supplement 1, Figure S1, confirming that this approach for clustering (for non-brain images) and generation is valid. Note that to generate the images, I followed the 2-generation_explained.py tutorial, where: (1) the training label maps contain all classes (0-11), (2) generation_classes are defined to cover all the classes in the training label maps such that SynthSeg generates the intensities for all anatomical structures (hence resulting in the samples shown above), and (3) the synthseg_segmentation_labels are defined to be only 0 and 1 so that the target label is only the background and the cord (same as what we have for our contrast-agnostic model). Note also that this edition of revisiting the SynthSeg comparison does not use TotalSeg labels (instead, we simply start with the already-available GT cord segmentations).
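
The three settings listed above can be written down as plain arrays. This is a sketch under assumptions: the variable names are borrowed from the text and may not match the tutorial script exactly, and giving every label its own generation class is one possible reading of point (2).

```python
import numpy as np

n_clusters = 10                                   # upper end of the 3-10 range used above

# (1) every id present in the training label maps: 0 (background), 1 (cord), 2..11 (clusters)
generation_labels = np.arange(0, n_clusters + 2)

# (2) assumption: each label gets its own class, so each has its own Gaussian
generation_classes = np.arange(generation_labels.size)

# (3) segmentation target: keep only the cord (1); everything else becomes background (0)
output_labels = np.where(generation_labels == 1, 1, 0)

print(generation_labels.tolist())
print(output_labels.tolist())
```

The key point is that `generation_labels` and `output_labels` have the same length: intensities are synthesized for all 12 labels, but only the cord survives as a segmentation target.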

The next step is to train SynthSeg on our data and do a proper comparison with the contrast-agnostic model.

@naga-karthik
Collaborator Author

Results from SynthSeg training are in! Thanks to the proper implementation of the clustering, the model did predict something this time! A few examples are shown below:

sub-balgrist06_T1w_synthseg

Image

sub-balgrist06_T2w_synthseg

Image

We see that SynthSeg's prediction for T1w is noticeably better than the one for T2w. For the T2star and MTon contrasts, the model did not predict anything. Note that I only used the T1w images to do the intensity-based clustering, following Table 1 in the original paper. Based on these results, it seems that SynthSeg is "contrast-agnostic" only as long as the field-of-view of the images is the same. For T2w, the model did predict parts of the cord, but only because the field-of-view/resolution is similar to that of T1w. For the rest of the contrasts, the images are highly anisotropic and only contain partial views of the cord (in the S-I plane).
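
A quick way to quantify the field-of-view and anisotropy differences mentioned above is to compare voxel spacings and physical extents per contrast. The helpers below are hypothetical, and the numbers in the example are placeholders chosen for illustration, not values read from the spine-generic images.

```python
def anisotropy_ratio(spacing_mm):
    """Largest voxel dimension divided by the smallest (1.0 means isotropic)."""
    return max(spacing_mm) / min(spacing_mm)


def physical_extent_mm(shape, spacing_mm):
    """Field-of-view in mm along each axis."""
    return tuple(n * s for n, s in zip(shape, spacing_mm))


# Placeholder values only: an isotropic volume vs. a thick-slice volume.
examples = {
    "isotropic (placeholder)": ((192, 260, 320), (1.0, 1.0, 1.0)),
    "thick-slice (placeholder)": ((384, 384, 15), (0.5, 0.5, 5.0)),
}
for name, (shape, spacing) in examples.items():
    print(name, anisotropy_ratio(spacing), physical_extent_mm(shape, spacing))
```

Running this on the actual image headers for each contrast would show how far each one departs from the T1w geometry the model was trained on.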

Next step: Running one more experiment where 4 contrasts (i.e. T1w, T2w, T2star, DWI) are used when generating labels using K-Means clustering. Maybe exposure to the labels covering all kinds of fields-of-view in the dataset would improve SynthSeg performance.
