As part of the UIUC CS598 Deep Learning for Healthcare course, we decided to reproduce INPREM: An Interpretable and Trustworthy Predictive Model for Healthcare. The goal of this project is to reproduce the experimental results in the paper and to add an ablation of our choosing to see how it influences the outcome.
The paper proposes the INPREM model (and its variations) and compares it against widely used baseline models:
- CNN
- RNN
- RNN+
- Dipole
- RETAIN
INPREM is intended for clinical prediction tasks, for which the baseline models above are not well suited. The experiments use data from MIMIC-III, specifically `DIAGNOSES_ICD.csv`. The INPREM implementation was provided to us by one of the authors, and we implemented the other baseline models in Python. We compared Code-Level Accuracy@k and Visit-Level Precision@k for k = 5 to 30 in increments of 5.
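For orientation, here is a minimal sketch, assuming the standard MIMIC-III `DIAGNOSES_ICD.csv` schema (`SUBJECT_ID`, `HADM_ID`, `ICD9_CODE`), of how the file can be grouped into per-patient visit sequences; the actual preprocessing in `main.py` may differ:

```python
import pandas as pd

# Assumption: standard MIMIC-III DIAGNOSES_ICD.csv columns
# (SUBJECT_ID, HADM_ID, SEQ_NUM, ICD9_CODE).
df = pd.read_csv("data/DIAGNOSES_ICD.csv")

# One visit = the set of ICD-9 codes recorded for one hospital admission.
visits = (
    df.dropna(subset=["ICD9_CODE"])
      .groupby(["SUBJECT_ID", "HADM_ID"])["ICD9_CODE"]
      .apply(list)
)

# One patient = the list of that patient's visits (chronological ordering
# would additionally require admission times from ADMISSIONS.csv).
patients = visits.groupby(level="SUBJECT_ID").apply(list)
print(patients.head())
```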
Create a Python 3.8 environment.

To install the requirements:

```bash
pip install -r requirements.txt
```

(Optional) For Windows and Linux users, to run the models on your GPU, install the CUDA-enabled build of PyTorch:

```bash
pip3 install torch --extra-index-url https://download.pytorch.org/whl/cu113
```
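After installing, you can confirm that the CUDA build is active with a quick check (a minimal sketch, not part of the project scripts):

```python
import torch

# True if a CUDA-capable GPU and the CUDA build of torch are available.
print(torch.cuda.is_available())
# CUDA version the installed wheel targets (None for CPU-only builds).
print(torch.version.cuda)
```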
When running `main.py` to train the models, you must pass the `--train` flag along with the model specification, `--model {model}`, replacing `{model}` with one of the available models.
The default arguments are what we used to train the model in our experiments.
The following arguments can be set (a hedged `argparse` sketch follows the list):

- `--model`: type of model to run. Choices: `CNN`, `RNN`, `RETAIN`, `DIPOLE`, `RNNplus`, `INPREM`, `INPREM_b`, `INPREM_s`, `INPREM_o`. Default: `INPREM`
- `--emb_dim`: size of the medical variable (or code) embedding. Default: `256`
- `--train`: boolean to train the model or use a pre-trained model. Default: `False`
- `--epochs`: number of training iterations. Default: `25`
- `--batch-size`: batch size for the data. Default: `32`
- `--drop_rate`: drop-out rate before each weight layer. Default: `0.5`
- `--optimizer`: optimizer for the model. Choices: `SGD`, `Adadelta`, `Adam`. Default: `Adam`
- `--lr`: learning rate for each step. Default: `5e-4`
- `--weight_decay`: weight decay for the model run. Default: `1e-4`
- `--save_model_dir`: directory to save the model with the best performance. Default: `os.path.join(base_dir, 'saved_models')`
- `--data_csv`: data file used to train and evaluate a model. Default: `os.path.join(base_dir, 'data', 'DIAGNOSES_ICD.csv')`
- `--icd9map`: location of the ICD-9 code to category mapping. Default: `os.path.join(base_dir, 'data', 'icd9_map.json')`
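For orientation, here is a minimal sketch of how these flags might be wired up with `argparse`; names and defaults mirror the list above, but the actual parser in `main.py` may differ in its details:

```python
import argparse
import os

base_dir = os.path.dirname(os.path.abspath(__file__))

parser = argparse.ArgumentParser(description="Train/evaluate INPREM and baseline models")
parser.add_argument('--model', default='INPREM',
                    choices=['CNN', 'RNN', 'RETAIN', 'DIPOLE', 'RNNplus',
                             'INPREM', 'INPREM_b', 'INPREM_s', 'INPREM_o'])
parser.add_argument('--emb_dim', type=int, default=256)
parser.add_argument('--train', action='store_true')   # defaults to False
parser.add_argument('--epochs', type=int, default=25)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--drop_rate', type=float, default=0.5)
parser.add_argument('--optimizer', default='Adam',
                    choices=['SGD', 'Adadelta', 'Adam'])
parser.add_argument('--lr', type=float, default=5e-4)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--save_model_dir',
                    default=os.path.join(base_dir, 'saved_models'))
parser.add_argument('--data_csv',
                    default=os.path.join(base_dir, 'data', 'DIAGNOSES_ICD.csv'))
parser.add_argument('--icd9map',
                    default=os.path.join(base_dir, 'data', 'icd9_map.json'))
args = parser.parse_args()
```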
A sample command that can be run is:

```bash
python3 main.py --model=RNN --train
```

All models will run the evaluation step after training is complete. To evaluate a pre-trained model, run:

```bash
python3 main.py --model=RNN
```
A model evaluation reports the following metrics (a sketch of the two @k metrics follows the list):
- ROC AUC (Area Under the Receiver Operating Characteristic Curve)
- Visit-level precision @ k for k={5, 10, 15, 20, 25, 30}
- Code-level accuracy @ k for k={5, 10, 15, 20, 25, 30}
- Time taken to test the model
- Total time taken to run the entire script
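For reference, here is a minimal sketch of how the two @k metrics are commonly computed for multi-label code prediction; the exact normalization used by this repository's evaluation code may differ:

```python
import numpy as np

def visit_level_precision_at_k(y_true, y_score, k):
    """Per visit: fraction of the top-k predicted codes that are true codes,
    normalized by min(k, number of true codes), then averaged over visits."""
    top_k = np.argsort(-y_score, axis=1)[:, :k]
    hits = np.take_along_axis(y_true, top_k, axis=1).sum(axis=1)
    denom = np.minimum(k, y_true.sum(axis=1)).clip(min=1)
    return float(np.mean(hits / denom))

def code_level_accuracy_at_k(y_true, y_score, k):
    """Fraction of all true codes (across visits) that appear in the top-k
    predictions of their visit."""
    top_k = np.argsort(-y_score, axis=1)[:, :k]
    hits = np.take_along_axis(y_true, top_k, axis=1).sum()
    return float(hits / y_true.sum())

# Toy example: y_true is a binary visit-by-code matrix, y_score the predicted scores.
y_true = np.array([[1, 0, 1, 0], [0, 1, 0, 0]])
y_score = np.array([[0.9, 0.1, 0.3, 0.2], [0.2, 0.8, 0.1, 0.4]])
print(visit_level_precision_at_k(y_true, y_score, 2))
print(code_level_accuracy_at_k(y_true, y_score, 2))
```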
Pre-trained models are available in the `saved_models` folder of this repository. You can load additional pre-trained models from another directory by passing the `--save_model_dir` flag.
All models were trained using the default parameter flags.
| | Model | ROC AUC | Time to Test [sec] | Time to Run [sec] |
|---|---|---|---|---|
| Baselines | CNN | 0.9083 | 2.05 | 9.40 |
| | RNN | 0.8992 | 1.62 | 9.27 |
| | RNN+ | 0.9142 | 1.58 | 9.21 |
| | RETAIN | 0.8843 | 1.62 | 9.22 |
| | Dipole | 0.9009 | 1.63 | 9.25 |
| INPREM | INPREM | 0.4918 | 2.87 | 10.31 |
| | INPREMb- | 0.6051 | 3.15 | 10.66 |
| | INPREMo- | 0.5879 | 3.03 | 10.84 |
| | INPREMs- | 0.6053 | 2.16 | 9.93 |
Acc@k = Code-Level Accuracy@k; Prec@k = Visit-Level Precision@k.

| | Model | Acc@5 | Acc@10 | Acc@15 | Acc@20 | Acc@25 | Acc@30 | Prec@5 | Prec@10 | Prec@15 | Prec@20 | Prec@25 | Prec@30 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Baselines | CNN | 0.5266 | 0.5736 | 0.5736 | 0.5736 | 0.5736 | 0.5736 | 0.5280 | 0.5279 | 0.5869 | 0.6638 | 0.7320 | 0.7779 |
| | RNN | 0.4155 | 0.5363 | 0.5567 | 0.5582 | 0.5582 | 0.5582 | 0.6089 | 0.5758 | 0.6246 | 0.6874 | 0.7426 | 0.7834 |
| | RNN+ | 0.6238 | 0.7094 | 0.7123 | 0.7123 | 0.7123 | 0.7123 | 0.6300 | 0.5931 | 0.6409 | 0.7053 | 0.7630 | 0.8071 |
| | RETAIN | 0.5601 | 0.6481 | 0.6557 | 0.6557 | 0.6557 | 0.6557 | 0.6009 | 0.5661 | 0.6198 | 0.6858 | 0.7482 | 0.7919 |
| | Dipole | 0.5635 | 0.5635 | 0.5635 | 0.5635 | 0.5635 | 0.5635 | 0.4772 | 0.4756 | 0.5442 | 0.6348 | 0.7028 | 0.7539 |
| INPREM | INPREM | 0.0249 | 0.0249 | 0.0249 | 0.0249 | 0.0249 | 0.0249 | 0.0102 | 0.0065 | 0.0059 | 0.0065 | 0.0067 | 0.0079 |
| | INPREMb- | 0.4815 | 0.5143 | 0.5143 | 0.5143 | 0.5143 | 0.5143 | 0.4146 | 0.2629 | 0.2321 | 0.2270 | 0.2264 | 0.2263 |
| | INPREMo- | 0.5393 | 0.5409 | 0.5409 | 0.5409 | 0.5409 | 0.5409 | 0.3689 | 0.2228 | 0.1987 | 0.1948 | 0.1943 | 0.1945 |
| | INPREMs- | 0.4817 | 0.5185 | 0.5185 | 0.5185 | 0.5185 | 0.5185 | 0.4117 | 0.2641 | 0.2337 | 0.2286 | 0.2280 | 0.2279 |