
CS598DL4H Project: Reproducing INPREM

Project Details

As part of the UIUC CS598 Deep Learning for Healthcare course, we reproduce INPREM: An Interpretable and Trustworthy Predictive Model for Healthcare. The goal of this project is to reproduce the experimental results reported in the paper and to add an ablation of our choosing to see how it influences the outcome.

Paper Details

The paper implements the INPREM model (and its variations) and compares it against widely used baseline models:

  • CNN
  • RNN
  • RNN+
  • Dipole
  • RETAIN

INPREM is designed for clinical prediction tasks, for which many existing models are poorly suited. Data from MIMIC-III is used, specifically DIAGNOSES_ICD.csv. The INPREM implementation was provided to us by one of the authors, and we implemented the other baseline models in Python. We compared code-level accuracy@k and visit-level precision@k for k = 5, 10, 15, 20, 25, and 30.
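
As an illustration, the per-patient visit sequences consumed by the models can be built from DIAGNOSES_ICD.csv roughly as follows (a minimal sketch assuming the standard MIMIC-III columns SUBJECT_ID, HADM_ID, and ICD9_CODE; the repository's actual preprocessing may differ):

import pandas as pd

# Load the MIMIC-III diagnosis table.
df = pd.read_csv('data/DIAGNOSES_ICD.csv')

# Collect the diagnosis codes of each hospital admission into one visit,
# then collect each patient's visits into a sequence. (True chronological
# ordering would additionally require ADMITTIME from ADMISSIONS.csv.)
visits = df.groupby(['SUBJECT_ID', 'HADM_ID'])['ICD9_CODE'].apply(list)
patients = visits.groupby(level='SUBJECT_ID').apply(list)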

Requirements

Create a python 3.8 environment.
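
For example, using the built-in venv module (any environment manager works):

python3.8 -m venv .venv
source .venv/bin/activate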

To install requirements:

pip install -r requirements.txt

(Optional) For Windows and Linux users, to enable the models to run on your GPU, install the appropriate CUDA build of torch.

pip3 install torch --extra-index-url https://download.pytorch.org/whl/cu113
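
You can confirm that the CUDA build is active with a one-liner (torch.cuda.is_available() is the standard PyTorch check):

python3 -c "import torch; print(torch.cuda.is_available())"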

Training

When running main.py to train a model, you must pass the --train flag along with a model specification, --model {model}. Replace {model} with one of the available models. The default arguments are the ones we used to train the models in our experiments.

The following arguments can be set (an argparse sketch matching these flags follows the list):

  • --model: type of model to run
    • default: INPREM
    • choices: CNN, RNN, RETAIN, DIPOLE, RNNplus, INPREM, INPREM_b, INPREM_s, INPREM_o
  • --emb_dim: size of medical variable (or code) embedding
    • default: 256
  • --train: boolean to train the model or use the pre-trained model
    • default: False
  • --epochs: number of training epochs
    • default: 25
  • --batch-size: batch size for data
    • default: 32
  • --drop_rate: drop-out rate before each weight layer
    • default: 0.5
  • --optimizer: optimizer for model
    • default: Adam
    • choices: SGD, Adadelta, Adam
  • --lr: learning rate for each step
    • default: 5e-4
  • --weight_decay: weight decay for the model run
    • default: 1e-4
  • --save_model_dir: directory to save the model with the best performance
    • default: os.path.join(base_dir, 'saved_models')
  • --data_csv: data file which will be used to train and evaluate a model
    • default: os.path.join(base_dir, 'data', 'DIAGNOSES_ICD.csv')
  • --icd9map: location for ICD9 code mapping to categories
    • default: os.path.join(base_dir, 'data', 'icd9_map.json')
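
For reference, the flags above correspond to an argparse configuration along these lines (names, defaults, and choices taken from this list; the actual main.py may declare them differently):

import argparse
import os

base_dir = os.path.dirname(os.path.abspath(__file__))

parser = argparse.ArgumentParser(description='Train or evaluate INPREM and baseline models')
parser.add_argument('--model', default='INPREM',
                    choices=['CNN', 'RNN', 'RETAIN', 'DIPOLE', 'RNNplus',
                             'INPREM', 'INPREM_b', 'INPREM_s', 'INPREM_o'])
parser.add_argument('--emb_dim', type=int, default=256)
parser.add_argument('--train', action='store_true')
parser.add_argument('--epochs', type=int, default=25)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--drop_rate', type=float, default=0.5)
parser.add_argument('--optimizer', default='Adam', choices=['SGD', 'Adadelta', 'Adam'])
parser.add_argument('--lr', type=float, default=5e-4)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--save_model_dir', default=os.path.join(base_dir, 'saved_models'))
parser.add_argument('--data_csv', default=os.path.join(base_dir, 'data', 'DIAGNOSES_ICD.csv'))
parser.add_argument('--icd9map', default=os.path.join(base_dir, 'data', 'icd9_map.json'))
args = parser.parse_args()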

A sample command that can be run is:

python3 main.py --model=RNN --train

Evaluation

All models will run the evaluation step after training is complete. To evaluate a pre-trained model, run:

python3 main.py --model=RNN

A model evaluation will result in the following metrics:

  • ROC AUC (Area Under the Receiver Operating Characteristic Curve)
  • Visit-level precision @ k for k={5, 10, 15, 20, 25, 30}
  • Code-level accuracy @ k for k={5, 10, 15, 20, 25, 30}
  • Time taken to test the model
  • Total time taken to run the entire script
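
For clarity, the two @k metrics are commonly computed per visit as below (a sketch following the conventions of the sequential diagnosis prediction literature, e.g. the Dipole paper; the repository's implementation may differ in detail):

import numpy as np

def visit_level_precision_at_k(y_true, y_score, k):
    # y_true: binary indicator vector over all codes for one visit.
    # y_score: predicted score per code. Dividing by min(k, #true codes)
    # lets a perfect ranking score 1.0 even when the visit has fewer than k codes.
    top_k = np.argsort(y_score)[::-1][:k]
    return y_true[top_k].sum() / min(k, int(y_true.sum()))

def code_level_accuracy_at_k(y_true, y_score, k):
    # Fraction of the visit's true codes that appear in the top-k predictions.
    top_k = np.argsort(y_score)[::-1][:k]
    return y_true[top_k].sum() / int(y_true.sum())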

Pre-trained Models

Pre-trained models are available in the saved_models folder of this repository. You can load additional pre-trained models from another directory by specifying the --save_model_dir flag.

All models were trained using the default parameter flags.
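
Illustratively, evaluating one of these checkpoints could look like the following (a sketch that assumes whole-model checkpoints written with torch.save(model, path); the filename RNN.pt is hypothetical, and main.py's loading logic may differ):

import os
import torch

model_path = os.path.join('saved_models', 'RNN.pt')  # hypothetical filename
model = torch.load(model_path, map_location='cpu')   # assumes torch.save(model, path)
model.eval()  # disable dropout before evaluation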

Results

| | Model | ROC AUC | Time to Test [sec] | Time to Run [sec] |
|---|---|---|---|---|
| Baselines | CNN | 0.9083 | 2.05 | 9.40 |
| | RNN | 0.8992 | 1.62 | 9.27 |
| | RNN+ | 0.9142 | 1.58 | 9.21 |
| | RETAIN | 0.8843 | 1.62 | 9.22 |
| | Dipole | 0.9009 | 1.63 | 9.25 |
| INPREM | INPREM | 0.4918 | 2.87 | 10.31 |
| | INPREM_b- | 0.6051 | 3.15 | 10.66 |
| | INPREM_o- | 0.5879 | 3.03 | 10.84 |
| | INPREM_s- | 0.6053 | 2.16 | 9.93 |
Code-Level Accuracy@k:

| | Model | k=5 | k=10 | k=15 | k=20 | k=25 | k=30 |
|---|---|---|---|---|---|---|---|
| Baselines | CNN | 0.5266 | 0.5736 | 0.5736 | 0.5736 | 0.5736 | 0.5736 |
| | RNN | 0.4155 | 0.5363 | 0.5567 | 0.5582 | 0.5582 | 0.5582 |
| | RNN+ | 0.6238 | 0.7094 | 0.7123 | 0.7123 | 0.7123 | 0.7123 |
| | RETAIN | 0.5601 | 0.6481 | 0.6557 | 0.6557 | 0.6557 | 0.6557 |
| | Dipole | 0.5635 | 0.5635 | 0.5635 | 0.5635 | 0.5635 | 0.5635 |
| INPREM | INPREM | 0.0249 | 0.0249 | 0.0249 | 0.0249 | 0.0249 | 0.0249 |
| | INPREM_b- | 0.4815 | 0.5143 | 0.5143 | 0.5143 | 0.5143 | 0.5143 |
| | INPREM_o- | 0.5393 | 0.5409 | 0.5409 | 0.5409 | 0.5409 | 0.5409 |
| | INPREM_s- | 0.4817 | 0.5185 | 0.5185 | 0.5185 | 0.5185 | 0.5185 |

Visit-Level Precision@k:

| | Model | k=5 | k=10 | k=15 | k=20 | k=25 | k=30 |
|---|---|---|---|---|---|---|---|
| Baselines | CNN | 0.5280 | 0.5279 | 0.5869 | 0.6638 | 0.7320 | 0.7779 |
| | RNN | 0.6089 | 0.5758 | 0.6246 | 0.6874 | 0.7426 | 0.7834 |
| | RNN+ | 0.6300 | 0.5931 | 0.6409 | 0.7053 | 0.7630 | 0.8071 |
| | RETAIN | 0.6009 | 0.5661 | 0.6198 | 0.6858 | 0.7482 | 0.7919 |
| | Dipole | 0.4772 | 0.4756 | 0.5442 | 0.6348 | 0.7028 | 0.7539 |
| INPREM | INPREM | 0.0102 | 0.0065 | 0.0059 | 0.0065 | 0.0067 | 0.0079 |
| | INPREM_b- | 0.4146 | 0.2629 | 0.2321 | 0.2270 | 0.2264 | 0.2263 |
| | INPREM_o- | 0.3689 | 0.2228 | 0.1987 | 0.1948 | 0.1943 | 0.1945 |
| | INPREM_s- | 0.4117 | 0.2641 | 0.2337 | 0.2286 | 0.2280 | 0.2279 |
