diff --git a/README.md b/README.md
index 67d5440..15d506c 100644
--- a/README.md
+++ b/README.md
@@ -11,17 +11,46 @@ Website: [Nextflow Graph Machine Learning](https://jbris.github.io/nextflow-graph-machine-learning)
 - [Nextflow Graph Machine Learning](#nextflow-graph-machine-learning)
 - [Table of contents](#table-of-contents)
 - [Introduction](#introduction)
-- [The pipeline](#the-pipeline)
+- [The Nextflow pipeline](#the-nextflow-pipeline)
+- [Python Environment](#python-environment)
+  - [MLOps](#mlops)
+- [ArangoDB](#arangodb)
 
 # Introduction
 
-The purpose of this project is to provide a simple demonstration of how to construct a Nextflow pipeline, with MLOps integration, for performing gene regulatory network (GRN) reconstruction using graph neural networks (GNNs).
+The purpose of this project is to provide a simple demonstration of how to construct a Nextflow pipeline, with MLOps integration, for performing gene regulatory network (GRN) reconstruction using graph neural networks (GNNs). In practice, GRN reconstruction is an unsupervised link prediction problem.
 
-# The pipeline
+[For developing GNNs, we use PyTorch Geometric.](https://pytorch-geometric.readthedocs.io/en/latest/)
+
+# The Nextflow pipeline
+
+[Nextflow has been included to orchestrate the GRN reconstruction pipeline.](https://www.nextflow.io/)
 
 The pipeline is composed of the following steps:
 
 1. Exploratory data analysis: View the GRN and calculate some summary statistics.
 2. Processing: Process the graph feature matrix and edge list. Remove the disconnected subgraph.
 3. ArangoDB importing: Import the graph into ArangoDB.
-4. Train a graph neural network using SAGE convolutional layers.
+4. GNN training: Train a GNN using SAGE convolutional layers.
+5. VAE GNN training: Train a variational autoencoder (VAE) GNN and save the neural embeddings (see the sketch after this list).
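Since the training steps above frame GRN reconstruction as link prediction, the following is a minimal PyTorch Geometric sketch of the edge-level split those steps rely on. The toy graph and dimensions are illustrative assumptions rather than the pipeline's data:

```python
import torch
import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_geometric.utils import erdos_renyi_graph

# A toy undirected graph: 50 nodes with 8 features each and random edges.
data = Data(x=torch.randn(50, 8), edge_index=erdos_renyi_graph(50, edge_prob=0.1))

# Hold out edges for validation and testing; because split_labels=True,
# positive and negative edge labels are stored as separate attributes.
transform = T.RandomLinkSplit(
    num_val=0.05,
    num_test=0.1,
    is_undirected=True,
    add_negative_train_samples=False,
    split_labels=True,
)
train_data, val_data, test_data = transform(data)
print(train_data.pos_edge_label_index.shape)
```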
+
+# Python Environment
+
+[Python dependencies are specified in this requirements.txt file.](services/python/requirements.txt)
+
+These dependencies are installed during the build process for the following Docker image: ghcr.io/jbris/nextflow-graph-machine-learning:1.0.0
+
+Execute the following command to pull the image: `docker pull ghcr.io/jbris/nextflow-graph-machine-learning:1.0.0`
+
+## MLOps
+
+* [A Docker compose file has been provided to launch an MLOps stack.](docker-compose.yml)
+* [See the .env file for Docker environment variables.](.env)
+* [The docker_up.sh script can be executed to launch the Docker services.](scripts/docker_up.sh)
+* [DVC is included for data version control.](https://dvc.org/)
+* [MLFlow is available for experiment tracking.](https://mlflow.org/)
+* [MinIO is available for storing experiment artifacts.](https://min.io/)
+
+# ArangoDB
+
+[This pipeline provides a simple demonstration of saving graph data to, and retrieving it from, ArangoDB, with NetworkX integration via the ArangoDB-NetworkX adapter.](https://docs.arangodb.com/3.11/data-science/adapters/arangodb-networkx-adapter/)
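A hedged sketch of the adapter round trip this section describes, mirroring the `get_graph` function in `bin/train_vae_gnn.py` below. The connection details are placeholders, and exact `networkx_to_arangodb` keyword names vary across adbnx-adapter releases:

```python
import networkx as nx
from adbnx_adapter import ADBNX_Adapter
from arango import ArangoClient

# Placeholder connection details; see the .env file for the real values.
db = ArangoClient(hosts="http://localhost:8529").db(
    "_system", username="root", password="password"
)
adapter = ADBNX_Adapter(db)

# NetworkX -> ArangoDB: node IDs of the form "collection/key" map onto
# ArangoDB documents under the default adapter controller.
G = nx.relabel_nodes(nx.karate_club_graph(), lambda n: f"person/{n}")
edge_definitions = [
    {
        "edge_collection": "knows",
        "from_vertex_collections": ["person"],
        "to_vertex_collections": ["person"],
    }
]
adapter.networkx_to_arangodb("karate", G, edge_definitions)

# ArangoDB -> NetworkX: read the graph back, as get_graph() does below.
G2 = adapter.arangodb_graph_to_networkx("karate")
```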
diff --git a/bin/train_vae_gnn.py b/bin/train_vae_gnn.py
new file mode 100644
index 0000000..45539d6
--- /dev/null
+++ b/bin/train_vae_gnn.py
@@ -0,0 +1,500 @@
+#!/usr/bin/env python
+
+######################################
+# Imports
+######################################
+
+from adbnx_adapter import ADBNX_Adapter
+from arango import ArangoClient
+import hydra
+import matplotlib.pyplot as plt
+import mlflow
+import networkx as nx
+from omegaconf import DictConfig
+from os.path import join as join_path
+import pandas as pd
+from pathlib import Path
+import torch
+from sklearn.decomposition import PCA
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.data import Data
+from torch_geometric.utils import from_networkx
+import torch_geometric.transforms as T
+from torch_geometric.nn import SAGEConv, VGAE
+
+######################################
+# Classes
+######################################
+
+
+class VariationalGCNEncoder(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+        n_layers: int = 2,
+        normalize: bool = False,
+        bias: bool = True,
+        aggr: str = "mean",
+    ) -> None:
+        """
+        VariationalGCNEncoder constructor.
+
+        Args:
+            in_channels (int):
+                The number of input channels.
+            hidden_channels (int):
+                The number of hidden channels.
+            out_channels (int):
+                The number of output channels.
+            n_layers (int, optional):
+                The number of hidden SAGE convolutional layers. Defaults to 2.
+            normalize (bool, optional):
+                Whether to apply normalisation. Defaults to False.
+            bias (bool, optional):
+                Whether to include the bias term. Defaults to True.
+            aggr (str, optional):
+                The tensor aggregation type. Defaults to "mean".
+        """
+        super().__init__()
+        self.layers = nn.ModuleList()
+        self.conv1 = SAGEConv(
+            in_channels, hidden_channels, normalize=normalize, aggr=aggr, bias=bias
+        )
+
+        self.conv2 = SAGEConv(
+            hidden_channels, 2 * out_channels, normalize=normalize, aggr=aggr, bias=bias
+        )
+
+        self.conv_mu = SAGEConv(
+            2 * out_channels, out_channels, normalize=normalize, aggr=aggr, bias=bias
+        )
+
+        self.conv_logstd = SAGEConv(
+            2 * out_channels, out_channels, normalize=normalize, aggr=aggr, bias=bias
+        )
+
+        # The message-passing stack: conv1, followed by n_layers hidden SAGE layers.
+        self.layers.append(self.conv1)
+        for _ in range(n_layers):
+            self.layers.append(
+                SAGEConv(
+                    hidden_channels,
+                    hidden_channels,
+                    normalize=normalize,
+                    aggr=aggr,
+                    bias=bias,
+                )
+            )
+
+        self.activation = F.leaky_relu
+
+    def forward(
+        self, x: torch.Tensor, edge_index: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        The forward pass.
+
+        Args:
+            x (torch.Tensor):
+                Input data.
+            edge_index (torch.Tensor):
+                The graph edge index.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]:
+                The convolutional mean and log-standard deviation.
+        """
+        for layer in self.layers:
+            x = layer(x, edge_index)
+            x = self.activation(x)
+
+        x = self.conv2(x, edge_index)
+        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)
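A minimal shape check for the encoder above — a sketch assuming the class is importable, with illustrative dimensions and a random toy graph:

```python
import torch
from torch_geometric.utils import erdos_renyi_graph

# Toy input: 100 nodes with 16 features each, plus random undirected edges.
x = torch.randn(100, 16)
edge_index = erdos_renyi_graph(100, edge_prob=0.05)

encoder = VariationalGCNEncoder(in_channels=16, hidden_channels=32, out_channels=8)
mu, logstd = encoder(x, edge_index)

# Both heads project down to out_channels per node.
assert mu.shape == logstd.shape == (100, 8)
```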
+
+
+######################################
+# Functions
+######################################
+
+
+def log_results(
+    tracking_uri: str,
+    experiment_prefix: str,
+    grn_name: str,
+    in_channels: int,
+    config: DictConfig,
+) -> None:
+    """
+    Log experiment results to the experiment tracker.
+
+    Args:
+        tracking_uri (str):
+            The tracking URI.
+        experiment_prefix (str):
+            The experiment name prefix.
+        grn_name (str):
+            The name of the GRN.
+        in_channels (int):
+            The number of input channels.
+        config (DictConfig):
+            The pipeline configuration.
+    """
+    mlflow.set_tracking_uri(tracking_uri)
+    experiment_name = f"{experiment_prefix}_train_vae_gnn"
+    existing_exp = mlflow.get_experiment_by_name(experiment_name)
+    if not existing_exp:
+        mlflow.create_experiment(experiment_name)
+    mlflow.set_experiment(experiment_name)
+
+    mlflow.set_tag("grn", grn_name)
+    mlflow.set_tag("gnn", "VAE")
+
+    mlflow.log_param("grn", grn_name)
+    mlflow.log_param("in_channels", in_channels)
+
+    for k in config["gnn"]:
+        mlflow.log_param(k, config["gnn"][k])
+
+
+def get_graph(
+    db_host: str,
+    db_name: str,
+    db_username: str,
+    db_password: str,
+    collection: str,
+    feature_k: str = "expression",
+) -> nx.Graph:
+    """
+    Retrieve the graph from the database.
+
+    Args:
+        db_host (str):
+            The database host.
+        db_name (str):
+            The database name.
+        db_username (str):
+            The database username.
+        db_password (str):
+            The database password.
+        collection (str):
+            The database collection.
+        feature_k (str):
+            The dictionary key for node features.
+
+    Returns:
+        nx.Graph:
+            The retrieved graph.
+    """
+    db = ArangoClient(hosts=db_host).db(
+        db_name, username=db_username, password=db_password
+    )
+    adapter = ADBNX_Adapter(db)
+    db_G = adapter.arangodb_graph_to_networkx(collection)
+    db_G = nx.Graph(db_G)
+    db_G = nx.convert_node_labels_to_integers(db_G)
+
+    # Rebuild the graph so that each node carries a flat feature list.
+    G = nx.Graph()
+    G.add_edges_from(db_G.edges)
+    for node_id, node_features in list(db_G.nodes(data=True)):
+        features = list(node_features[feature_k].values())
+        G.nodes[node_id][feature_k] = features
+
+    return G
+
+
+def get_split(
+    G: Data, num_val: float, num_test: float, device: torch.device
+) -> tuple[Data, Data, Data]:
+    """
+    Get the train-validation-test split.
+
+    Args:
+        G (Data):
+            The graph data object.
+        num_val (float):
+            The proportion of validation data.
+        num_test (float):
+            The proportion of testing data.
+        device (torch.device):
+            The training device.
+
+    Returns:
+        tuple[Data, Data, Data]:
+            The train-validation-test split.
+    """
+    transform = T.Compose(
+        [
+            # Normalise the node features stored under the "expression" key
+            # (the default would only look for an "x" attribute).
+            T.NormalizeFeatures(attrs=["expression"]),
+            T.ToDevice(device),
+            T.RandomLinkSplit(
+                num_val=num_val,
+                num_test=num_test,
+                is_undirected=True,
+                add_negative_train_samples=False,
+                split_labels=True,
+            ),
+        ]
+    )
+
+    train_data, val_data, test_data = transform(G)
+    return train_data, val_data, test_data
+
+
+def get_model_components(
+    lr: float,
+    in_channels: int,
+    hidden_channels: int,
+    out_channels: int,
+    device: torch.device,
+    n_layers: int,
+    normalize: bool,
+    bias: bool,
+    aggr: str,
+) -> tuple:
+    """
+    Get the components for training the model.
+
+    Args:
+        lr (float):
+            The learning rate.
+        in_channels (int):
+            The number of input channels.
+        hidden_channels (int):
+            The number of hidden channels.
+        out_channels (int):
+            The number of output channels.
+        device (torch.device):
+            The training device.
+        n_layers (int):
+            The number of SAGE convolutional layers.
+        normalize (bool):
+            Whether to normalize the input tensors.
+        bias (bool):
+            Whether to include the bias term.
+        aggr (str):
+            The data aggregation method.
+
+    Returns:
+        tuple:
+            The model, optimiser, and learning rate scheduler.
+    """
+    model = VGAE(
+        VariationalGCNEncoder(
+            in_channels,
+            hidden_channels,
+            out_channels,
+            n_layers,
+            normalize,
+            bias,
+            aggr,
+        )
+    ).to(device)
+    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
+    # Note: the scheduler is returned for completeness, but is not currently
+    # stepped within the training loop.
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer, "max", factor=0.05
+    )
+
+    return model, optimizer, scheduler
+
+
+def train_model(
+    model: torch.nn.Module,
+    train_data: Data,
+    val_data: Data,
+    test_data: Data,
+    n_epochs: int,
+    optimizer: torch.optim.Optimizer,
+    device: torch.device,
+    enable_tracking: bool,
+) -> torch.nn.Module:
+    """
+    Train the graph neural network.
+
+    Args:
+        model (torch.nn.Module):
+            The graph neural network.
+        train_data (Data):
+            The training data.
+        val_data (Data):
+            The validation data.
+        test_data (Data):
+            The testing data.
+        n_epochs (int):
+            The number of epochs.
+        optimizer (torch.optim.Optimizer):
+            The model optimiser.
+        device (torch.device):
+            The training device.
+        enable_tracking (bool):
+            Whether to enable experiment tracking.
+
+    Returns:
+        torch.nn.Module:
+            The trained model.
+    """
+
+    def train():
+        model.train()
+        optimizer.zero_grad()
+        z = model.encode(train_data.expression, train_data.edge_index)
+        # The VGAE loss: edge reconstruction plus a node-normalised KL term.
+        loss = model.recon_loss(z, train_data.pos_edge_label_index)
+        loss = loss + (1 / train_data.num_nodes) * model.kl_loss()
+        loss.backward()
+        optimizer.step()
+        return float(loss)
+
+    @torch.no_grad()
+    def test(data):
+        model.eval()
+        z = model.encode(data.expression, data.edge_index)
+        return model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)
+
+    # Log and print at roughly 5% intervals, guarding against small n_epochs.
+    log_interval = max(1, int(n_epochs * 0.05))
+    for epoch in range(n_epochs):
+        loss = train()
+        val_auc, val_ap = test(val_data)
+        test_auc, test_ap = test(test_data)
+
+        if epoch % log_interval == 0:
+            if enable_tracking:
+                mlflow.log_metric("train_loss", loss, step=epoch)
+                mlflow.log_metric("val_auc", val_auc, step=epoch)
+                mlflow.log_metric("val_ap", val_ap, step=epoch)
+                mlflow.log_metric("test_auc", test_auc, step=epoch)
+                mlflow.log_metric("test_ap", test_ap, step=epoch)
+
+            print(
+                f"Epoch: {epoch:03d}, loss {loss:.4f}",
+                f"Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}",
+                f"Test AUC: {test_auc:.4f}, Test AP: {test_ap:.4f}",
+            )
+
+    final_test_auc, final_test_ap = test(test_data)
+    print(f"Final Test AUC: {final_test_auc:.4f}, Final Test AP: {final_test_ap:.4f}")
+
+    if enable_tracking:
+        mlflow.log_metric("final_test_auc", final_test_auc)
+        mlflow.log_metric("final_test_ap", final_test_ap)
+
+    return model
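For reference, the loss computed in the inner `train()` function above is the standard variational graph auto-encoder objective (Kipf & Welling, 2016), with the KL divergence weighted by the reciprocal of the node count:

$$
\mathcal{L} = \mathcal{L}_{\text{recon}}\left(Z, E^{+}\right) + \frac{1}{N}\,\mathrm{KL}\left(q(Z \mid X, A) \,\|\, p(Z)\right)
$$

where $Z$ are the latent node embeddings, $E^{+}$ the positive training edges, $X$ the node features, $A$ the adjacency structure, and $N$ the number of nodes.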
+ """ + embeddings = model.encode(data.expression, data.edge_index).detach().cpu().numpy() + transformer = PCA(n_components=2) + emb_transformed = pd.DataFrame( + transformer.fit_transform(embeddings), columns=["x", "y"] + ) + emb_transformed.plot.scatter("x", "y") + + Path(output_dir).mkdir(parents=True, exist_ok=True) + outfile = join_path(output_dir, "graph.png") + plt.savefig(outfile) + + if enable_tracking: + mlflow.log_artifact(outfile) + + return outfile + + +###################################### +# Main +###################################### + + +@hydra.main(version_base=None, config_path="../conf", config_name="config") +def main(config: DictConfig) -> None: + """ + The main entry point for the plotting pipeline. + + Args: + config (DictConfig): + The pipeline configuration. + """ + EXPERIMENT_PREFIX = config["experiment"]["name"] + + DATA_DIR = config["dir"]["data_dir"] + OUT_DIR = config["dir"]["out_dir"] + + GRN_NAME = config["grn"]["input_dir"] + + DB_HOST = config["db"]["host"] + DB_NAME = config["db"]["name"] + DB_USERNAME = config["db"]["username"] + DB_PASSWORD = config["db"]["password"] + + NUM_VAL = config["gnn"]["num_val"] + NUM_TEST = config["gnn"]["num_test"] + HIDDEN_CHANNELS = config["gnn"]["hidden_channels"] + OUT_CHANNELS = config["gnn"]["out_channels"] + LR = config["gnn"]["lr"] + N_EPOCHS = config["gnn"]["n_epochs"] + N_LAYERS = config["gnn"]["n_layers"] + NORMALIZE = config["gnn"]["normalize"] + BIAS = config["gnn"]["bias"] + AGGR = config["gnn"]["aggr"] + + TRACKING_URI = config["experiment_tracking"]["tracking_uri"] + ENABLE_TRACKING = config["experiment_tracking"]["enabled"] + + G = get_graph(DB_HOST, DB_NAME, DB_USERNAME, DB_PASSWORD, GRN_NAME) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + G = from_networkx(G) + + train_data, val_data, test_data = get_split(G, NUM_VAL, NUM_TEST, device) + + in_channels = G.expression.shape[1] + model, optimizer, scheduler = get_model_components( + LR, + in_channels, + HIDDEN_CHANNELS, + OUT_CHANNELS, + device, + N_LAYERS, + NORMALIZE, + BIAS, + AGGR, + ) + + if ENABLE_TRACKING: + log_results(TRACKING_URI, EXPERIMENT_PREFIX, GRN_NAME, in_channels, config) + + model = train_model( + model, + train_data, + val_data, + test_data, + N_EPOCHS, + optimizer, + device, + ENABLE_TRACKING, + ) + + output_dir = join_path(DATA_DIR, OUT_DIR, GRN_NAME, "vae_gnn") + view_embeddings(model, train_data, output_dir, ENABLE_TRACKING) + + if ENABLE_TRACKING: + mlflow.end_run() + + +if __name__ == "__main__": + main() diff --git a/conf/gnn/vae.yaml b/conf/gnn/vae.yaml new file mode 100644 index 0000000..40fb722 --- /dev/null +++ b/conf/gnn/vae.yaml @@ -0,0 +1,10 @@ +n_layers: 3 +normalize: false +bias: true +aggr: mean +num_val: 0.05 +num_test: 0.05 +hidden_channels: 128 +out_channels: 128 +lr: 0.001 +n_epochs: 250 \ No newline at end of file diff --git a/data/.gitignore b/data/.gitignore new file mode 100644 index 0000000..aab52d9 --- /dev/null +++ b/data/.gitignore @@ -0,0 +1 @@ +*.png \ No newline at end of file diff --git a/docs/source/pipelines/index.rst b/docs/source/pipelines/index.rst index b50266b..b2aeadc 100644 --- a/docs/source/pipelines/index.rst +++ b/docs/source/pipelines/index.rst @@ -8,4 +8,5 @@ Nextflow Graph Machine Learning Pipelines eda.rst process.rst to_db.rst - train_gnn.rst \ No newline at end of file + train_gnn.rst + train_vae_gnn.rst \ No newline at end of file diff --git a/docs/source/pipelines/train_vae_gnn.rst b/docs/source/pipelines/train_vae_gnn.rst new file mode 100644 index 
diff --git a/data/.gitignore b/data/.gitignore
new file mode 100644
index 0000000..aab52d9
--- /dev/null
+++ b/data/.gitignore
@@ -0,0 +1 @@
+*.png
\ No newline at end of file
diff --git a/docs/source/pipelines/index.rst b/docs/source/pipelines/index.rst
index b50266b..b2aeadc 100644
--- a/docs/source/pipelines/index.rst
+++ b/docs/source/pipelines/index.rst
@@ -8,4 +8,5 @@ Nextflow Graph Machine Learning Pipelines
    eda.rst
    process.rst
    to_db.rst
-   train_gnn.rst
\ No newline at end of file
+   train_gnn.rst
+   train_vae_gnn.rst
\ No newline at end of file
diff --git a/docs/source/pipelines/train_vae_gnn.rst b/docs/source/pipelines/train_vae_gnn.rst
new file mode 100644
index 0000000..64d4ba3
--- /dev/null
+++ b/docs/source/pipelines/train_vae_gnn.rst
@@ -0,0 +1,7 @@
+Train a Variational Autoencoder Graph Neural Network
+=====================================================
+
+*Date published:* |today|
+
+.. automodule:: bin.train_vae_gnn
+   :members:
\ No newline at end of file
diff --git a/gnn_pipeline.nf b/gnn_pipeline.nf
index 207ba9c..8855e08 100644
--- a/gnn_pipeline.nf
+++ b/gnn_pipeline.nf
@@ -7,7 +7,8 @@ nextflow.enable.dsl=2
 // ##################################################################
 
 include { toDb } from './modules/db.nf'
-include { trainSAGE } from './modules/gnn.nf'
+include { trainSAGE as SAGE } from './modules/gnn.nf'
+include { trainVAE as VAE } from './modules/gnn.nf'
 include { dvcRepro as dvc } from './modules/mlops.nf'
 
 // ##################################################################
@@ -22,10 +23,17 @@ params.featureMatrix = "expression_data.csv"
 // Workflow
 // ##################################################################
 
+/**
+* In practice, we should parallelise the training of the GraphSAGE and VAE
+* neural networks - which is what we'd do in an HPC environment.
+* This would also introduce async channels and observables.
+* But I'm running this on my laptop.
+*/
 workflow {
     processedDir = dvc(params.grn, projectDir)
     (db_log, grn_db) = toDb(params.grn, projectDir, processedDir, params.featureMatrix, params.edgeList)
-    gnn_res = trainSAGE(grn_db)
+    (gnn_log, gnn_db) = SAGE(grn_db)
+    VAE(gnn_db)
 }
 
 workflow.onComplete {
diff --git a/modules/gnn.nf b/modules/gnn.nf
index 8bdd993..af9af5f 100644
--- a/modules/gnn.nf
+++ b/modules/gnn.nf
@@ -17,4 +17,21 @@ process trainSAGE {
     """
     train_gnn.py grn.input_dir="$grn" db.password=$ARANGO_ROOT_PASSWORD > train_gnn.log
     """
+}
+
+process trainVAE {
+    tag "Gene regulatory network: $grn"
+
+    input:
+    val grn
+
+    output:
+    path 'train_vae_gnn.log'
+    val grn_db
+
+    script:
+    grn_db = grn
+    """
+    train_vae_gnn.py grn.input_dir="$grn" db.password=$ARANGO_ROOT_PASSWORD gnn=vae > train_vae_gnn.log
+    """
+}
\ No newline at end of file