forked from ed-wilkes/predictive-modelling
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplotConfMatrix.R
62 lines (58 loc) · 2.18 KB
/
plotConfMatrix.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#' @name plotConfMatrix
#' @author Ed Wilkes
#'
#' @description Averages the confusion matrices of repeated resampling
#'   predictions and plots the result as a heatmap. Cells are coloured by the
#'   percentage of each reference (row) class that was predicted as each
#'   (column) class, and annotated with the averaged counts.
#'
#' @param data Data frame of out-of-sample predictions with columns `pred`
#'   (predicted class), `obs` (observed class) and `fold` (resample label).
#'   `pred` and `obs` must be factors with the same levels, as required by
#'   `caret::confusionMatrix()`. The repeat each row belongs to is the text
#'   after the first "." in `fold` (e.g. "Fold01.Rep2" gives "Rep2"); labels
#'   without a "." are pooled into a single repeat.
#' @param breaks Breaks for heatmap colours, defaults to seq(0, 100, 5)
#' @param colours Colours for heatmap, defaults to c("white", "red2")
#' @param order Optional row/column order (class names or indices) applied to
#'   the confusion matrix before plotting; NULL keeps the default order.
#'
#' @return The list returned by `gplots::heatmap.2()`, invisibly.
#'
plotConfMatrix <- function(data,
                           breaks = seq(0, 100, 5),
                           colours = c("white", "red2"),
                           order = NULL) {
  ## Required packages: library() errors immediately if one is missing
  ## (require() only warned, deferring a confusing failure to later code)
  library(caret)
  library(dplyr)
  library(gplots)
  library(stringr)
  library(reshape2)

  ## Validate input early
  missing_cols <- setdiff(c("pred", "obs", "fold"), names(data))
  if (length(missing_cols) > 0) {
    stop("`data` is missing column(s): ",
         paste(missing_cols, collapse = ", "), call. = FALSE)
  }
  if (nrow(data) == 0) {
    stop("`data` has no rows", call. = FALSE)
  }

  ## Repeat identifier = text after the first "." in `fold`. Labels with no
  ## "." are pooled into one repeat; previously they became NA and were
  ## silently dropped by split(), breaking non-repeated resampling.
  fold <- as.character(data$fold)
  dot_pos <- str_locate(fold, fixed("."))[, "start"]
  data$rep <- ifelse(is.na(dot_pos), "", substring(fold, dot_pos + 1))

  ## One long-format confusion table (Prediction, Reference, Freq) per repeat
  list_results <- lapply(split(data, data$rep), function(x) {
    as.data.frame(confusionMatrix(data = x$pred, reference = x$obs)$table)
  })

  ## Average confusion matrices across repeats, rounded to whole counts
  df_all <- bind_rows(list_results) %>%
    group_by(Reference, Prediction) %>%
    summarise(Freq = round(mean(Freq), 0), .groups = "drop")

  ## Wide count matrix: rows = Reference, columns = Prediction
  conf_wide <- reshape2::dcast(df_all, Reference ~ Prediction,
                               value.var = "Freq")
  conf_mat <- as.matrix(conf_wide[, -1, drop = FALSE])
  rownames(conf_mat) <- as.character(conf_wide[, 1])

  ## Row-normalise to percentages of each reference class
  conf_mat_ratio <- sweep(conf_mat, 1, rowSums(conf_mat), `/`) * 100
  if (!is.null(order)) {
    # drop = FALSE keeps matrix structure (and dimnames) for any selection
    conf_mat_ratio <- conf_mat_ratio[order, order, drop = FALSE]
    conf_mat <- conf_mat[order, order, drop = FALSE]
  }

  ## Plot: colour = percentage, cell text = averaged count
  hm_colours <- colorRampPalette(colours)(length(breaks) - 1)
  hm_pred <- heatmap.2(conf_mat_ratio,
                       trace = "none",
                       dendrogram = "none",
                       Colv = FALSE,
                       Rowv = FALSE,
                       col = hm_colours,
                       breaks = breaks,
                       cellnote = conf_mat,
                       notecol = "black")

  # Plotting function: return the heatmap.2 result without printing it
  invisible(hm_pred)
}