WAE-DTI: Ensemble-based architecture for drug-target interaction prediction using descriptors and embeddings
Clone the project from GitHub and install the necessary dependencies.
git clone https://github.com/tariqshaban/wae-dti
cd wae-dti
pip install -r requirements.txt
PyTorch needs to be installed (preferably with CUDA-enabled GPU acceleration).
The program will automatically attempt to clone the ESM repository if it is not present. Either install Git on your machine or manually download the repository and place it in ../esm.
Tip
You can also clone the ESM repository beforehand by running the following command:
git clone https://github.com/facebookresearch/esm ../esm
├── config
│ └── config.py <- Store terminal arguments from entry files.
├── data
│ ├── embeddings
│ │ ├── drug_embedding <- Pre-trained drug embeddings (Parquet format).
│ │ └── target_embedding <- Pre-trained target embeddings (Parquet format).
│ └── raw
│ ├── classification <- Raw classification datasets for training and evaluation (Parquet format).
│ └── regression <- Raw regression datasets for training and evaluation (Parquet format).
├── results <- Stores trained models, predictions, and metrics.
├── saved_models
│ ├── classification <- Trained models and their performance on the classification task.
│ └── regression <- Trained models and their performance on the regression task.
├── src
│ ├── dti_dataset.py <- Data preprocessing and mounting prior to training.
│ ├── dti_model.py <- Neural network definition for the WAE-DTI architecture.
│ ├── ensemble.py <- Implementation of the weighted average ensemble method.
│ ├── evaluate.py <- Mathematical methods to evaluate model performance.
│ ├── predict.py <- Provide predictions of a model given a dataloader.
│ ├── test.py <- Generate and evaluate predictions of the model given external examples.
│ └── train.py <- WAE-DTI model trainer.
├── utils
│ ├── embedding
│ │ ├── extractor
│ │ │ ├── drug_extractor.py <- Extract and save embeddings from a list of drugs (Parquet format).
│ │ │ └── target_extractor.py <- Extract and save embeddings from a list of targets (Parquet format).
│ │ └── loader
│ │ ├── drug_loader.py <- Load drug embeddings from a Parquet file.
│ │ └── target_loader.py <- Load target embeddings from a Parquet file.
│ ├── eda.py <- Perform exploratory data analysis (EDA) on all datasets.
│ └── terminal_command_runner.py <- Run external terminal commands in real time.
├── entry_inference.py <- Entry point for inference.
├── entry_train.py <- Entry point for training.
├── README.md <- README file and documentation.
└── requirements.txt <- List of dependencies required to run the project.
entry_train.py [-h] [--use-wandb USE_WANDB] --task {classification,regression} [--dataset DATASET] [--seed SEED] [--epochs EPOCHS] [--patience PATIENCE]
[--eda EDA] [--learning-rate LEARNING_RATE] [--batch-size BATCH_SIZE] [--torch-device TORCH_DEVICE]
Argument | Description | Default | Notes |
---|---|---|---|
`-h`, `--help` | Display a help message | | |
`--use-wandb USE_WANDB` | Log training and validation metrics into W&B | `False` | If set to `True`, the code will initialize a project named after the selected dataset |
`--task` | Supervised learning algorithm type | | Must be set to either `classification` or `regression` |
`--dataset DATASET` | Dataset used for training | `'davis'` | The dataset must be in the data/raw path with train and test folders containing Parquet files |
`--seed SEED` | Ensures reproducibility (on the same device) | | |
`--epochs EPOCHS` | Number of training epochs | `3000` | |
`--patience PATIENCE` | Number of epochs allowed to elapse without improvement before training stops | `200` | |
`--eda EDA` | Conduct a quick EDA on startup | `False` | EDA is applied to all datasets regardless of the specified `--dataset` |
`--learning-rate LEARNING_RATE` | Learning rate | `5e-4` | |
`--batch-size BATCH_SIZE` | Number of examples in each batch | `1024` | |
`--torch-device TORCH_DEVICE` | Device used for training (e.g. `cuda:0`, `cpu`) | | If not specified, the GPU will be used (if available) |
entry_inference.py [-h] --models-path MODELS_PATH --input-file INPUT_FILE
[--batch-size BATCH_SIZE] [--torch-device TORCH_DEVICE]
Argument | Description | Default | Notes |
---|---|---|---|
`-h`, `--help` | Display a help message | | |
`--models-path MODELS_PATH` | Folder path containing the models trained on each drug fingerprint | | You can use the pretrained models of any dataset within models/saved_models |
`--input-file INPUT_FILE` | CSV file path containing "drug", "target", and "label" columns ("label" is optional) | | If the "label" column is specified, you must have enough examples to satisfy the concordance index calculation requirement |
`--task` | Supervised learning algorithm type | | Must be set to either `classification` or `regression` |
`--batch-size BATCH_SIZE` | Number of examples in each batch | `1024` | |
`--torch-device TORCH_DEVICE` | Device used for inference (e.g. `cuda:0`, `cpu`) | | If not specified, the GPU will be used (if available) |
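The input file for inference can be assembled with Python's standard csv module. A minimal sketch is shown below; the column names come from the table above, while the SMILES string, protein sequence, and label value are purely illustrative placeholders:

```python
import csv

# Columns expected by entry_inference.py: "drug" (a SMILES string),
# "target" (a protein sequence), and an optional "label".
rows = [
    {
        "drug": "CC(=O)Oc1ccccc1C(=O)O",                # aspirin SMILES (illustrative)
        "target": "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",  # illustrative sequence
        "label": 5.0,                                   # optional ground-truth affinity
    },
]

with open("input.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["drug", "target", "label"])
    writer.writeheader()
    writer.writerows(rows)
```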
Note
When running entry_train.py and entry_inference.py, missing drug and target embeddings are automatically extracted and saved into data/embeddings.
Important
During training, you may notice that tqdm reports a monotonically increasing total value. This is caused by accounting for the dynamic fluctuation of the early stopping counter: rather than always setting the total equal to the number of epochs, the progress bar adjusts to display the smallest number of epochs needed to finish training, which translates to the following formula:
total = min(epochs, elapsed_epochs + patience - elapsed_patience)
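As a concrete illustration, the formula above can be expressed as a small helper (the function name is hypothetical, not part of the codebase):

```python
def progress_total(epochs: int, elapsed_epochs: int,
                   patience: int, elapsed_patience: int) -> int:
    """Smallest number of epochs the run can still take: either the full
    epoch budget, or the current epoch plus the remaining patience window."""
    return min(epochs, elapsed_epochs + patience - elapsed_patience)

# With a 3000-epoch budget and patience of 200: after 100 epochs, 50 of
# which passed without improvement, at most 250 epochs can elapse in total.
print(progress_total(3000, 100, 200, 50))  # → 250
```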
The following tables are the result of training the model using nine drug descriptors and one target embedding (ESM-2) on six datasets while repeating the experiment five times. The values represent the mean and standard deviation of each metric.
Model | Davis MSE | Davis CI | Davis R2 | Kiba MSE | Kiba CI | Kiba R2 |
---|---|---|---|---|---|---|
Atom pair fingerprint | 0.229 ± 0.006 | 0.892 ± 0.002 | 0.706 ± 0.007 | 0.155 ± 0.002 | 0.885 ± 0.001 | 0.751 ± 0.006 |
Avalon fingerprint | 0.213 ± 0.005 | 0.897 ± 0.004 | 0.718 ± 0.011 | 0.150 ± 0.002 | 0.883 ± 0.001 | 0.772 ± 0.006 |
MACCS keys fingerprint | 0.217 ± 0.002 | 0.895 ± 0.004 | 0.706 ± 0.007 | 0.171 ± 0.002 | 0.868 ± 0.003 | 0.732 ± 0.013 |
MH fingerprint | 0.217 ± 0.007 | 0.894 ± 0.002 | 0.709 ± 0.014 | 0.158 ± 0.002 | 0.883 ± 0.001 | 0.757 ± 0.014 |
Morgan fingerprint | 0.220 ± 0.005 | 0.895 ± 0.003 | 0.700 ± 0.011 | 0.157 ± 0.003 | 0.883 ± 0.001 | 0.753 ± 0.014 |
RDKit fingerprint | 0.222 ± 0.008 | 0.895 ± 0.002 | 0.711 ± 0.011 | 0.154 ± 0.001 | 0.885 ± 0.002 | 0.751 ± 0.010 |
SEC fingerprint | 0.219 ± 0.004 | 0.893 ± 0.004 | 0.711 ± 0.004 | 0.158 ± 0.002 | 0.883 ± 0.002 | 0.756 ± 0.009 |
Topological torsion fingerprint | 0.218 ± 0.007 | 0.896 ± 0.002 | 0.704 ± 0.012 | 0.158 ± 0.004 | 0.883 ± 0.003 | 0.747 ± 0.015 |
LDP | 0.310 ± 0.003 | 0.853 ± 0.003 | 0.597 ± 0.014 | 0.389 ± 0.003 | 0.756 ± 0.001 | 0.422 ± 0.007 |
Ensemble | 0.190 ± 0.001 | 0.915 ± 0.001 | 0.745 ± 0.003 | 0.127 ± 0.001 | 0.899 ± 0.000 | 0.778 ± 0.003 |
Model | DTC MSE | DTC CI | DTC R2 | Metz MSE | Metz CI | Metz R2 |
---|---|---|---|---|---|---|
Atom pair fingerprint | 0.188 ± 0.004 | 0.883 ± 0.001 | 0.787 ± 0.009 | 0.357 ± 0.004 | 0.787 ± 0.002 | 0.570 ± 0.005 |
Avalon fingerprint | 0.177 ± 0.003 | 0.883 ± 0.003 | 0.812 ± 0.013 | 0.325 ± 0.009 | 0.798 ± 0.005 | 0.616 ± 0.005 |
MACCS keys fingerprint | 0.201 ± 0.003 | 0.867 ± 0.003 | 0.800 ± 0.004 | 0.337 ± 0.008 | 0.793 ± 0.003 | 0.610 ± 0.005 |
MH fingerprint | 0.193 ± 0.006 | 0.877 ± 0.004 | 0.781 ± 0.010 | 0.358 ± 0.008 | 0.787 ± 0.002 | 0.575 ± 0.016 |
Morgan fingerprint | 0.191 ± 0.004 | 0.880 ± 0.002 | 0.789 ± 0.015 | 0.365 ± 0.003 | 0.785 ± 0.001 | 0.561 ± 0.004 |
RDKit fingerprint | 0.179 ± 0.005 | 0.885 ± 0.002 | 0.806 ± 0.017 | 0.330 ± 0.003 | 0.796 ± 0.001 | 0.610 ± 0.008 |
SEC fingerprint | 0.191 ± 0.001 | 0.881 ± 0.001 | 0.789 ± 0.007 | 0.364 ± 0.005 | 0.785 ± 0.002 | 0.563 ± 0.009 |
Topological torsion fingerprint | 0.189 ± 0.003 | 0.880 ± 0.001 | 0.786 ± 0.007 | 0.363 ± 0.002 | 0.787 ± 0.001 | 0.563 ± 0.006 |
LDP | 0.434 ± 0.005 | 0.768 ± 0.001 | 0.565 ± 0.007 | 0.532 ± 0.009 | 0.723 ± 0.003 | 0.418 ± 0.011 |
Ensemble | 0.143 ± 0.001 | 0.898 ± 0.001 | 0.839 ± 0.008 | 0.284 ± 0.004 | 0.813 ± 0.001 | 0.676 ± 0.006 |
Model | ToxCast MSE | ToxCast CI | ToxCast R2 | STITCH MSE | STITCH CI | STITCH R2 |
---|---|---|---|---|---|---|
Atom pair fingerprint | 0.326 ± 0.002 | 0.914 ± 0.002 | 0.553 ± 0.008 | 1.140 ± 0.003 | 0.751 ± 0.004 | 0.389 ± 0.004 |
Avalon fingerprint | 0.321 ± 0.003 | 0.916 ± 0.001 | 0.559 ± 0.009 | 1.078 ± 0.006 | 0.740 ± 0.006 | 0.417 ± 0.003 |
MACCS keys fingerprint | 0.324 ± 0.000 | 0.914 ± 0.001 | 0.561 ± 0.006 | 1.120 ± 0.006 | 0.708 ± 0.004 | 0.407 ± 0.002 |
MH fingerprint | very high MSE | very low CI | very low R2 | very high MSE | very low CI | very low R2 |
Morgan fingerprint | 0.333 ± 0.002 | 0.911 ± 0.002 | 0.540 ± 0.012 | 1.122 ± 0.006 | 0.763 ± 0.004 | 0.398 ± 0.005 |
RDKit fingerprint | 0.323 ± 0.002 | 0.915 ± 0.001 | 0.553 ± 0.005 | 1.143 ± 0.008 | 0.743 ± 0.005 | 0.389 ± 0.006 |
SEC fingerprint | 0.333 ± 0.003 | 0.911 ± 0.002 | 0.540 ± 0.009 | 1.167 ± 0.006 | 0.759 ± 0.005 | 0.374 ± 0.007 |
Topological torsion fingerprint | 0.331 ± 0.002 | 0.913 ± 0.001 | 0.541 ± 0.006 | 1.163 ± 0.003 | 0.754 ± 0.004 | 0.378 ± 0.004 |
LDP | 0.369 ± 0.001 | 0.898 ± 0.000 | 0.506 ± 0.002 | 1.542 ± 0.003 | 0.627 ± 0.002 | 0.201 ± 0.001 |
Ensemble | 0.308 ± 0.001 | 0.922 ± 0.000 | 0.581 ± 0.003 | 0.934 ± 0.004 | 0.772 ± 0.001 | 0.488 ± 0.003 |
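The Ensemble rows combine the per-fingerprint models' outputs via a weighted average. A minimal sketch of one plausible scheme, with weights inversely proportional to each model's validation MSE, is given below; the actual weighting used in src/ensemble.py may differ:

```python
def weighted_average_ensemble(predictions, val_mse):
    """Combine per-model predictions with weights proportional to 1/MSE.

    predictions: dict mapping model name -> list of predicted affinities
    val_mse:     dict mapping model name -> validation MSE of that model
    """
    inv = {name: 1.0 / mse for name, mse in val_mse.items()}
    total = sum(inv.values())
    weights = {name: w / total for name, w in inv.items()}  # weights sum to 1

    n = len(next(iter(predictions.values())))
    return [
        sum(weights[name] * predictions[name][i] for name in predictions)
        for i in range(n)
    ]

# Two hypothetical fingerprint models; the lower-MSE model gets more weight.
preds = {"morgan": [5.0, 7.0], "avalon": [6.0, 8.0]}
mse = {"morgan": 0.220, "avalon": 0.213}
print(weighted_average_ensemble(preds, mse))
```

Because every weight lies between 0 and 1 and they sum to 1, each ensembled prediction is bounded by the per-model predictions it combines.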
In addition to the regression results above, three datasets were used to evaluate the model on the classification task.
BioSNAP:

Model | Sensitivity | Specificity | AUC | AUPRC | Threshold |
---|---|---|---|---|---|
Atom pair fingerprint | 0.824 ± 0.008 | 0.830 ± 0.019 | 0.899 ± 0.005 | 0.905 ± 0.002 | 0.560 ± 0.251 |
Avalon fingerprint | 0.859 ± 0.008 | 0.872 ± 0.005 | 0.932 ± 0.002 | 0.933 ± 0.001 | 0.582 ± 0.121 |
MACCS keys fingerprint | 0.873 ± 0.006 | 0.884 ± 0.006 | 0.942 ± 0.002 | 0.945 ± 0.002 | 0.557 ± 0.126 |
Morgan fingerprint | 0.789 ± 0.008 | 0.830 ± 0.004 | 0.880 ± 0.003 | 0.892 ± 0.004 | 0.701 ± 0.082 |
RDKit fingerprint | 0.841 ± 0.002 | 0.857 ± 0.009 | 0.915 ± 0.003 | 0.917 ± 0.003 | 0.670 ± 0.186 |
SEC fingerprint | 0.795 ± 0.006 | 0.826 ± 0.011 | 0.879 ± 0.001 | 0.888 ± 0.002 | 0.682 ± 0.053 |
Topological torsion fingerprint | 0.807 ± 0.014 | 0.834 ± 0.010 | 0.888 ± 0.004 | 0.896 ± 0.003 | 0.442 ± 0.098 |
LDP | 0.767 ± 0.010 | 0.805 ± 0.012 | 0.858 ± 0.003 | 0.862 ± 0.001 | 0.537 ± 0.047 |
Ensemble | 0.862 ± 0.003 | 0.871 ± 0.006 | 0.935 ± 0.001 | 0.943 ± 0.001 | 0.478 ± 0.018 |
Davis:

Model | Sensitivity | Specificity | AUC | AUPRC | Threshold |
---|---|---|---|---|---|
Atom pair fingerprint | 0.892 ± 0.009 | 0.853 ± 0.021 | 0.929 ± 0.005 | 0.408 ± 0.014 | 0.559 ± 0.168 |
Avalon fingerprint | 0.882 ± 0.013 | 0.847 ± 0.005 | 0.925 ± 0.005 | 0.431 ± 0.014 | 0.457 ± 0.082 |
MACCS keys fingerprint | 0.890 ± 0.006 | 0.862 ± 0.009 | 0.929 ± 0.003 | 0.415 ± 0.021 | 0.564 ± 0.119 |
Morgan fingerprint | 0.881 ± 0.019 | 0.855 ± 0.017 | 0.926 ± 0.003 | 0.390 ± 0.016 | 0.590 ± 0.289 |
RDKit fingerprint | 0.876 ± 0.020 | 0.844 ± 0.011 | 0.924 ± 0.004 | 0.416 ± 0.002 | 0.485 ± 0.093 |
SEC fingerprint | 0.879 ± 0.018 | 0.852 ± 0.009 | 0.926 ± 0.004 | 0.398 ± 0.022 | 0.561 ± 0.125 |
Topological torsion fingerprint | 0.874 ± 0.022 | 0.851 ± 0.007 | 0.919 ± 0.005 | 0.383 ± 0.009 | 0.618 ± 0.048 |
LDP | 0.797 ± 0.015 | 0.728 ± 0.014 | 0.827 ± 0.003 | 0.239 ± 0.003 | 0.537 ± 0.059 |
Ensemble | 0.900 ± 0.013 | 0.867 ± 0.010 | 0.936 ± 0.002 | 0.474 ± 0.011 | 0.528 ± 0.055 |
BindingDB:

Model | Sensitivity | Specificity | AUC | AUPRC | Threshold |
---|---|---|---|---|---|
Atom pair fingerprint | 0.869 ± 0.009 | 0.830 ± 0.015 | 0.909 ± 0.011 | 0.636 ± 0.014 | 0.481 ± 0.054 |
Avalon fingerprint | 0.883 ± 0.007 | 0.834 ± 0.012 | 0.918 ± 0.003 | 0.651 ± 0.005 | 0.457 ± 0.045 |
MACCS keys fingerprint | 0.886 ± 0.009 | 0.841 ± 0.006 | 0.923 ± 0.001 | 0.671 ± 0.004 | 0.372 ± 0.082 |
Morgan fingerprint | 0.831 ± 0.005 | 0.794 ± 0.003 | 0.879 ± 0.002 | 0.597 ± 0.001 | 0.468 ± 0.052 |
RDKit fingerprint | 0.885 ± 0.004 | 0.832 ± 0.007 | 0.914 ± 0.001 | 0.634 ± 0.007 | 0.490 ± 0.076 |
SEC fingerprint | 0.844 ± 0.012 | 0.795 ± 0.006 | 0.881 ± 0.004 | 0.598 ± 0.002 | 0.500 ± 0.049 |
Topological torsion fingerprint | 0.876 ± 0.015 | 0.822 ± 0.014 | 0.908 ± 0.013 | 0.615 ± 0.016 | 0.518 ± 0.056 |
LDP | 0.818 ± 0.013 | 0.788 ± 0.019 | 0.873 ± 0.006 | 0.511 ± 0.008 | 0.489 ± 0.057 |
Ensemble | 0.885 ± 0.006 | 0.853 ± 0.007 | 0.931 ± 0.001 | 0.707 ± 0.005 | 0.517 ± 0.022 |
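The Threshold column above reports the decision cutoff applied to each classifier's output. One common way to pick such a cutoff is to maximize Youden's J statistic (sensitivity + specificity − 1); the sketch below illustrates that criterion, though the repository's actual selection method may differ:

```python
def best_threshold(scores, labels, candidates=None):
    """Pick the cutoff maximizing Youden's J = sensitivity + specificity - 1.

    scores: predicted interaction probabilities
    labels: ground-truth binary labels (1 = interacting)
    """
    if candidates is None:
        candidates = sorted(set(scores))
    best_t, best_j = 0.5, float("-inf")
    for t in candidates:
        tp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 1)
        fn = sum(1 for s, y in zip(scores, labels) if s < t and y == 1)
        tn = sum(1 for s, y in zip(scores, labels) if s < t and y == 0)
        fp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 0)
        sens = tp / (tp + fn) if tp + fn else 0.0
        spec = tn / (tn + fp) if tn + fp else 0.0
        j = sens + spec - 1.0
        if j > best_j:
            best_t, best_j = t, j
    return best_t

scores = [0.1, 0.4, 0.35, 0.8, 0.9]
labels = [0, 0, 1, 1, 1]
print(best_threshold(scores, labels))  # → 0.8
```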
Tariq Sha’ban, Ahmad M. Mustafa, Mostafa Z. Ali, WAE-DTI: Ensemble-based architecture for drug–target interaction prediction using descriptors and embeddings, Informatics in Medicine Unlocked, Volume 52, 2025, 101604, ISSN 2352-9148.
@article{SHABAN2025101604,
title = {WAE-DTI: Ensemble-based architecture for drug–target interaction prediction using descriptors and embeddings},
journal = {Informatics in Medicine Unlocked},
volume = {52},
pages = {101604},
year = {2025},
issn = {2352-9148},
doi = {10.1016/j.imu.2024.101604},
url = {https://www.sciencedirect.com/science/article/pii/S2352914824001618},
author = {Tariq Sha’ban and Ahmad M. Mustafa and Mostafa Z. Ali},
}