# Machine learning {#machine_learning}
```{r setup, echo=FALSE, results="asis"}
library(rebook)
chapterPreamble()
```
Machine learning (ML) is a part of artificial intelligence. There are multiple
definitions, but "machine" refers to computation and "learning" to improving
performance on a task by finding patterns in data. Machine learning includes a
wide variety of methods, from simple statistical methods to more complex
methods such as neural networks.
Machine learning can be divided into supervised and unsupervised machine learning.
Supervised ML is used to predict an outcome (a label) from the data. Unsupervised
ML is used, for example, to reduce dimensionality (e.g., PCA) and to find clusters
in the data (e.g., k-means clustering).
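To make the two unsupervised tasks concrete, here is a minimal base R sketch on
simulated data. The data and the `toy_*` object names are hypothetical and are not
part of the OMA workflow: `prcomp()` performs PCA and `kmeans()` finds clusters,
and neither uses any labels.

```{r unsupervised-sketch}
# Hypothetical toy data: 20 samples (rows) and 5 features (columns),
# drawn from two groups with different means; no labels are kept
set.seed(1)
toy_x <- rbind(matrix(rnorm(50, mean = 0), ncol = 5),
               matrix(rnorm(50, mean = 3), ncol = 5))
# Dimensionality reduction: PCA on centered and scaled features
toy_pca <- prcomp(toy_x, center = TRUE, scale. = TRUE)
# Proportion of variance explained by each principal component
summary(toy_pca)$importance["Proportion of Variance", ]
# Clustering: k-means with k = 2, using 10 random starts
toy_km <- kmeans(toy_x, centers = 2, nstart = 10)
# Number of samples assigned to each cluster
table(toy_km$cluster)
```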
## Supervised machine learning
"Supervised" means that the model is trained on data whose labels are already
known. The training data contains labels (e.g., patient status), and the model is
fitted to this training data. After fitting, the model is used to predict the
labels of data whose labels are not known.
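The fit-then-predict idea can be sketched in base R before turning to the
microbiome data. This is an illustration under assumptions: the data are
simulated, the `toy_*` names are hypothetical, and logistic regression (`glm()`)
stands in for the random forest and XGBoost models used later in this chapter.

```{r supervised-sketch}
# Hypothetical toy data: 100 samples with a known label and one feature
# whose mean is higher in the "disease" group
set.seed(2)
toy_df <- data.frame(status = factor(rep(c("healthy", "disease"), each = 50),
                                     levels = c("healthy", "disease")))
toy_df$feature <- rnorm(100, mean = ifelse(toy_df$status == "disease", 2, 0))
# Hold out 20% of the samples and treat their labels as unknown
toy_test_idx <- sample(nrow(toy_df), size = 20)
toy_train <- toy_df[-toy_test_idx, ]
toy_test <- toy_df[toy_test_idx, ]
# Fit a logistic regression on the labeled training data; glm() models the
# probability of the second factor level ("disease")
toy_fit <- glm(status ~ feature, data = toy_train, family = binomial)
# Predict labels for the held-out samples using a 0.5 probability cutoff
toy_prob <- predict(toy_fit, newdata = toy_test, type = "response")
toy_pred <- factor(ifelse(toy_prob > 0.5, "disease", "healthy"),
                   levels = levels(toy_df$status))
# Compare predicted and true labels of the held-out samples
table(predicted = toy_pred, true = toy_test$status)
mean(toy_pred == toy_test$status)
```

Only the held-out samples are used for the final comparison, because accuracy on
the training samples would be optimistic.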
```{r superML1}
library(mia)
# Load experimental data
data(peerj13075)
(tse <- peerj13075)
```
Let's first preprocess the data.
```{r super2}
# Agglomerate data to order level
tse <- agglomerateByRank(tse, rank = "order")
# Calculate relative abundances (a pseudocount of 1 is added to avoid zeros)
tse <- transformSamples(tse, method = "relabundance", pseudocount = 1)
# Apply CLR transform to the relative abundances
tse <- transformSamples(tse, assay_name = "relabundance", method = "clr")
# Get assay
assay <- assay(tse, "clr")
# Transpose assay so that samples are rows and features are columns
assay <- t(assay)
# Convert into data.frame
df <- as.data.frame(assay)
# Add labels (diet type) to the data.frame
labels <- colData(tse)$Diet
labels <- as.factor(labels)
df$diet <- labels
# Show the first rows and columns
df[1:5, 1:5]
```
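The CLR transformation above can be written out by hand for a small matrix. This
sketch assumes the same order of operations as the chunk above (pseudocount, then
relative abundance, then CLR); the `toy_*` names and counts are hypothetical, and
mia's implementation may differ in details.

```{r clr-sketch}
# Hypothetical count matrix: rows are taxa, columns are samples,
# as in the assays of a TreeSummarizedExperiment
toy_counts <- matrix(c(10, 0, 5,
                       3, 7, 0,
                       0, 2, 8,
                       20, 1, 1), nrow = 4, byrow = TRUE,
                     dimnames = list(paste0("taxon", 1:4), paste0("sample", 1:3)))
# Add a pseudocount so that the logarithm is defined for zero counts
toy_pc <- toy_counts + 1
# Relative abundance: divide each sample (column) by its total
toy_rel <- sweep(toy_pc, 2, colSums(toy_pc), "/")
# CLR: log of each value minus the mean log value of its sample
toy_log <- log(toy_rel)
toy_clr <- sweep(toy_log, 2, colMeans(toy_log), "-")
round(toy_clr, 2)
# The CLR values of each sample sum to zero
colSums(toy_clr)
```

Because CLR is unchanged when a sample is multiplied by a constant, the
relative-abundance step does not alter the CLR values; it only mirrors the
workflow above.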
In the example below, we use the [mikropml](https://journals.asm.org/doi/10.1128/mBio.00434-20)
package to predict the diet type from the data.
```{r super3}
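if( !require("mikropml") ){
    install.packages("mikropml")
    library(mikropml)
}
# caret provides confusionMatrix(), trainControl() and train()
library(caret)
# Set seed for reproducibility
set.seed(6358)
# Run random forest with 2-fold cross-validation repeated 5 times;
# 80% of the samples are used for training
results <- run_ml(df, "rf", outcome_colname = 'diet',
                  kfold = 2, cv_times = 5, training_frac = 0.8)
# Confusion matrix of the out-of-bag predictions on the training samples
confusionMatrix(data = results$trained_model$finalModel$predicted,
                reference = results$trained_model$finalModel$y)
# Performance on the held-out test samples
results$performance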
```
mikropml offers an easier interface to the [caret](https://cran.r-project.org/web/packages/caret/index.html)
package. However, we can also use caret directly.
Let's train an XGBoost model, another algorithm commonly used in bioinformatics.
```{r super4}
# Set seed for reproducibility
set.seed(6358)
# Specify train control: 5-fold cross-validation; twoClassSummary
# reports ROC AUC, sensitivity and specificity
train_control <- trainControl(method = "cv", number = 5,
                              classProbs = TRUE,
                              summaryFunction = twoClassSummary,
                              savePredictions = "final",
                              allowParallel = TRUE)
# Specify hyperparameter tuning grid
tune_grid <- expand.grid(nrounds = c(50, 100, 200),
                         max_depth = c(6, 8, 10),
                         colsample_bytree = c(0.6, 0.8, 1),
                         eta = c(0.1, 0.3),
                         gamma = 0,
                         min_child_weight = c(3, 4, 5),
                         subsample = c(0.6, 0.8)
                         )
# Train the model and select hyperparameters by cross-validated ROC AUC.
# For two classes, caret sets the "binary:logistic" objective itself.
model <- train(x = assay,
               y = labels,
               method = "xgbTree",
               trControl = train_control,
               tuneGrid = tune_grid,
               metric = "ROC",
               verbosity = 0
               )
```
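`trainControl(method = "cv", number = 5)` estimates performance with 5-fold
cross-validation. The base R sketch below shows what that means for a single
model. It is an illustration under assumptions: the data and `toy_*` names are
hypothetical, logistic regression replaces XGBoost, and accuracy replaces ROC AUC.
caret repeats this procedure for every row of `tune_grid` and then refits the best
combination on all samples.

```{r cv-sketch}
# Hypothetical toy data: 100 samples, two classes, one informative feature
set.seed(3)
toy_cv <- data.frame(y = factor(rep(c("A", "B"), each = 50)))
toy_cv$x <- rnorm(100, mean = ifelse(toy_cv$y == "B", 1.5, 0))
# Randomly assign each sample to one of k = 5 folds of equal size
toy_k <- 5
toy_folds <- sample(rep(seq_len(toy_k), length.out = nrow(toy_cv)))
# For each fold: fit on the other k - 1 folds, evaluate on the held-out fold
toy_acc <- sapply(seq_len(toy_k), function(i) {
    fit <- glm(y ~ x, data = toy_cv[toy_folds != i, ], family = binomial)
    held_out <- toy_cv[toy_folds == i, ]
    prob_b <- predict(fit, newdata = held_out, type = "response")
    pred <- ifelse(prob_b > 0.5, "B", "A")
    mean(pred == held_out$y)
})
# Accuracy in each fold and the cross-validated estimate (their mean)
toy_acc
mean(toy_acc)
```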
Let's plot a ROC curve, a commonly used way to evaluate binary classifiers.
For imbalanced data, you might want to plot a precision-recall curve instead.
```{r super5}
if( !require(MLeval) ){
    install.packages("MLeval")
    library(MLeval)
}
# Calculate different evaluation metrics
res <- evalm(model, showplots = FALSE)
# Use patchwork to plot ROC and precision-recall curve side-by-side;
# ggplot2 provides theme()
library(patchwork)
library(ggplot2)
res$roc + res$proc +
    plot_layout(guides = "collect") & theme(legend.position = 'bottom')
```
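For intuition about the curves above, the following base R sketch builds ROC and
precision-recall curves by sweeping a decision threshold over predicted scores,
and computes AUC as the probability that a randomly chosen positive sample scores
higher than a randomly chosen negative one. The scores and `toy_*` names are
hypothetical, and MLeval's own computations may differ in details.

```{r roc-sketch}
# Hypothetical toy data: true classes (1 = positive) and predicted scores
# for 30 positives and 70 negatives, i.e., a somewhat imbalanced data set
set.seed(4)
toy_truth <- c(rep(1, 30), rep(0, 70))
toy_score <- c(rbeta(30, 4, 2), rbeta(70, 2, 4))
toy_pos <- toy_score[toy_truth == 1]
toy_neg <- toy_score[toy_truth == 0]
# Use every observed score as a decision threshold, from high to low
toy_thr <- sort(unique(toy_score), decreasing = TRUE)
# True positive rate (sensitivity, recall) and false positive rate
toy_tpr <- sapply(toy_thr, function(thr) mean(toy_pos >= thr))
toy_fpr <- sapply(toy_thr, function(thr) mean(toy_neg >= thr))
# Precision: share of true positives among samples called positive
toy_prec <- sapply(toy_thr, function(thr) mean(toy_truth[toy_score >= thr] == 1))
# AUC: probability that a random positive scores higher than a random
# negative (ties count as one half)
toy_auc <- mean(outer(toy_pos, toy_neg, ">") + 0.5 * outer(toy_pos, toy_neg, "=="))
toy_auc
# Plot ROC and precision-recall curves side by side
par(mfrow = c(1, 2))
plot(c(0, toy_fpr), c(0, toy_tpr), type = "l",
     xlab = "False positive rate", ylab = "True positive rate", main = "ROC")
abline(0, 1, lty = 2)
plot(toy_tpr, toy_prec, type = "l", ylim = c(0, 1),
     xlab = "Recall", ylab = "Precision", main = "Precision-recall")
# Baseline of a random classifier: the share of positives
abline(h = mean(toy_truth), lty = 2)
```

The random-classifier baseline of the ROC curve is the diagonal regardless of class
balance, whereas the precision-recall baseline equals the share of positives. This
is why the precision-recall curve is often more informative for imbalanced data.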
## Unsupervised machine learning
"Unsupervised" means that the labels (e.g., patient status) are not known, and
patterns are learned based only on the data, for instance, the abundance table.
Unsupervised ML is closely related to data mining, where patterns are extracted
from large datasets.
For unsupervised machine learning, please refer to chapters that are listed below:
- Chapter \@ref(clustering)
- Chapter \@ref(community-similarity)
## Session Info {-}
```{r sessionInfo, echo=FALSE, results='asis'}
prettySessionInfo()
```