
Implement strategy for assessing the quality of the model during lifelong training #87

jcohenadad opened this issue Nov 9, 2023 · 9 comments

@jcohenadad (Member)

As we are adding more contrasts and retraining the model over time (see e.g. #83, #74, ivadomed/canproco#46), we need to put in place a quality-check assessment of model performance shifts across various data domains (i.e., monitor catastrophic forgetting).
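For concreteness, such a check could be a small regression harness that logs the per-image absolute CSA error of every released model on a fixed per-contrast test set, then flags contrasts whose error grows between releases. The sketch below is only illustrative: the CSV name, its columns, and the 2 mm² threshold are assumptions, not part of any existing pipeline.

```python
import pandas as pd

# Hypothetical log filled in after each release, one row per
# (model, contrast, image) with the absolute CSA error in mm^2.
df = pd.read_csv("csa_error_log.csv")

# Mean abs. CSA error per (model, contrast): the quantity to watch
# for catastrophic forgetting as new contrasts are added.
drift = (df.groupby(["model", "contrast"])["abs_csa_error_mm2"]
           .mean()
           .unstack("contrast"))
print(drift.round(2))

# Simple alert: flag contrasts whose error grew by more than an
# (arbitrary) 2 mm^2 relative to the previous release; assumes rows
# are sorted in release order.
regressions = drift.diff().gt(2.0)
print(regressions)
```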

@naga-karthik (Collaborator)

A relevant theory paper I found: Understanding Continual Learning Settings with Data Distribution Drift Analysis. It essentially describes the theory of data distribution shifts, proposes new concepts for analyzing model/data drift, and relates some existing lifelong-learning concepts to this phenomenon.

Relevant sections: Sections 3.1, 3.2, 4.1, and 6.2

@jcohenadad (Member, Author)

One validation idea is to compute the CSA variation across contrasts on the test set of the spine-generic data.
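A rough sketch of what that could look like, assuming a hypothetical CSV `csa_per_subject.csv` with columns `subject`, `contrast`, and `csa_mm2` (e.g., aggregated from `sct_process_segmentation` outputs):

```python
import pandas as pd

df = pd.read_csv("csa_per_subject.csv")  # hypothetical: subject, contrast, csa_mm2

# Per-subject spread of CSA across contrasts: for a truly
# contrast-agnostic model this should be near zero, since the cord's
# true CSA does not depend on the contrast it was imaged with.
spread = df.groupby("subject")["csa_mm2"].agg(["mean", "std"])
spread["cov_pct"] = 100 * spread["std"] / spread["mean"]
print(spread["cov_pct"].describe())
```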

@naga-karthik (Collaborator) commented Nov 29, 2024

With each new release adding more contrasts and datasets, we observed a disturbing upward trend in absolute CSA error across all the deployed models, suggesting that the drift is far too high and the model is, in some sense, losing its "contrast-agnostic"-ness.

[Figure: abs. CSA error drift across deployed models (abs_csa_error_perslice)]

The primary suspect is the dataset imbalance created by an unusually high number of T2w and T2star images in the training sets of the later models (i.e., v2.4, v2.5, etc.).

Next step:

  • create a balanced dataset with a similar number of images for each contrast (a possible sampling strategy is sketched below)
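A minimal sketch of such a capping strategy; the manifest file name and columns are assumptions, and the cap of 150 matches the value used in the next comment:

```python
import pandas as pd

CAP = 150  # target number of images per contrast
df = pd.read_csv("aggregated_dataset.csv")  # hypothetical manifest: one row per image

# Randomly keep at most CAP images per contrast; contrasts with fewer
# images are kept in full.
balanced = (
    df.groupby("contrast", group_keys=False)
      .apply(lambda g: g.sample(n=min(len(g), CAP), random_state=42))
)
print(balanced["contrast"].value_counts())
```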

@naga-karthik (Collaborator) commented Nov 29, 2024

I created a balanced version of the aggregated dataset. Below are the details and the splits. Each contrast now has approximately 150 images (a hard-coded cap chosen after looking at the total number of images per contrast). Except for a few contrasts (e.g., STIR), all contrasts have at least 150 images.

balanced dataset statistics
SPLITS ACROSS DIFFERENT CONTRASTS (n=11):

| contrast   |   train |   validation |   test |   #images_per_contrast |
|:-----------|--------:|-------------:|-------:|-----------------------:|
| dwi        |     150 |          102 |      0 |                    252 |
| mt-off     |     146 |           97 |      0 |                    243 |
| mt-on      |     150 |           98 |      0 |                    248 |
| part-mag   |     122 |           54 |      0 |                    176 |
| psir       |     150 |           96 |      0 |                    246 |
| stir       |      57 |           32 |      0 |                     89 |
| t1map      |      61 |           27 |      0 |                     88 |
| t1w        |     150 |          127 |      0 |                    277 |
| t2star     |     150 |          150 |    116 |                    416 |
| t2w        |     150 |          150 |    150 |                    450 |
| unit1      |     132 |           59 |      0 |                    191 |
| TOTAL      |    1418 |          992 |    266 |                   2676 |


CONTRAST-WISE PATHOLOGY SPLIT (a subject can have multiple contrasts):
includes train/val/test images

|          |   MS |   HC |   RIS |   RRMS |   PPMS |   DCM |   SPMS |   SCI |   ALS |   SYR |   NMO |   #total_per_contrast |
|:---------|-----:|-----:|------:|-------:|-------:|------:|-------:|------:|------:|------:|------:|----------------------:|
| dwi      |    1 |  173 |     0 |      0 |      0 |    68 |      0 |     4 |     5 |     1 |     0 |                   252 |
| mt-off   |    0 |  184 |     0 |      0 |      0 |    59 |      0 |     0 |     0 |     0 |     0 |                   243 |
| mt-on    |    0 |  179 |     0 |      0 |      0 |    65 |      0 |     0 |     4 |     0 |     0 |                   248 |
| part-mag |    0 |    0 |     0 |    138 |     10 |     0 |     28 |     0 |     0 |     0 |     0 |                   176 |
| psir     |    0 |   29 |    40 |    146 |     31 |     0 |      0 |     0 |     0 |     0 |     0 |                   246 |
| stir     |    0 |   10 |     7 |     56 |     16 |     0 |      0 |     0 |     0 |     0 |     0 |                    89 |
| t1map    |    0 |    0 |     0 |     69 |      5 |     0 |     14 |     0 |     0 |     0 |     0 |                    88 |
| t1w      |   17 |  204 |     0 |      0 |      0 |    43 |      0 |     0 |     0 |     1 |    12 |                   277 |
| t2star   |  207 |  125 |     0 |      0 |      0 |    67 |      0 |     2 |    14 |     1 |     0 |                   416 |
| t2w      |  119 |   75 |    12 |     54 |     14 |   109 |      0 |    62 |     5 |     0 |     0 |                   450 |
| unit1    |   50 |   53 |     0 |     69 |      5 |     0 |     14 |     0 |     0 |     0 |     0 |                   191 |
| TOTAL    |  394 | 1032 |    59 |    532 |     81 |   411 |     56 |    68 |    28 |     3 |    12 |                  2676 |

Next step:

  • train a model on the balanced dataset, compute the abs. CSA error, and compare with the deployed models (a comparison sketch follows below)
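For the comparison, a grouped bar chart of the mean abs. CSA error per contrast, with one bar per model version, would make a regression on any single contrast easy to spot. A sketch, assuming a hypothetical long-format CSV with columns `model`, `contrast`, and `abs_csa_error_mm2`:

```python
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("abs_csa_error_all_models.csv")  # hypothetical log

# Rows: contrasts; columns: model versions; cells: mean abs. CSA error (mm^2).
table = df.pivot_table(index="contrast", columns="model",
                       values="abs_csa_error_mm2", aggfunc="mean")

# One group of bars per contrast, one bar per model version.
table.plot(kind="bar", figsize=(10, 4), ylabel="abs. CSA error (mm$^2$)")
plt.tight_layout()
plt.savefig("abs_csa_error_per_contrast.png")
```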

@naga-karthik (Collaborator)

Something interesting happened: it seems that the imbalance in the number of images per contrast in the aggregated dataset is not what is causing the CSA drift.

[Figure: abs. CSA error, deployed models vs. balanced model (abs_csa_error_perslice)]

[Figure: CSA error per contrast for the balanced model (abs_error_per_contrast_balanced)]

Note that the per-contrast CSA errors are high for the T2star, DWI, and MToff (GRE-T1w) contrasts, even though the dataset statistics in the previous comment show these contrasts have 150 images each in the training set.

I am wondering whether any of the new contrasts added to the training set (after the original 6 contrasts in contrast-agnostic v2.0) are negatively impacting performance on the remaining contrasts.

@sandrinebedard (Member)

Interesting investigation! Do you have the absolute CSA error per contrast for each model version?

@naga-karthik (Collaborator)

yes!

[Figure: CSA error per contrast for the balanced model (abs_error_per_contrast_balanced)]
[Figure: CSA error per contrast for the v2.0 model (abs_error_per_contrast_v20)]
[Figure: CSA error per contrast for the v2.1 model (abs_error_per_contrast_v21)]
[Figure: CSA error per contrast for the v2.3 model (abs_error_per_contrast_v23)]
[Figure: CSA error per contrast for the v2.4 model (abs_error_per_contrast_v24)]
[Figure: CSA error per contrast for the v2.5 model (abs_error_per_contrast_v25)]

@sandrinebedard (Member)

Nice! Why does DWI have a lower error in v2.5 than in the balanced model (3.39 mm² vs. 12.34 mm²)?

@naga-karthik (Collaborator)

Since dataset imbalance is evidently not the only problem, I trained another model with only the DCM pathology data and 1 new contrast (closer to how model v2.3 was trained).

[Figure: CSA error per contrast for the balanced model trained on the spine-generic, UNIT1, and DCM datasets (abs_error_per_contrast_balanced)]
[Figure: abs. CSA error plot (abs_csa_error_perslice)]

It turns out that the per-contrast error is still high for a few contrasts, despite the model not having seen several new contrasts. This makes me wonder whether adding new contrasts is not the problem; maybe, in this case, having too many compressed cords in the data is the problem?
