Skip to content

Commit

Permalink
Speed: Replace mutate_if() by mutate() (#37)
Browse files Browse the repository at this point in the history
* Format the lapply

* replace mutate_if() by mutate()
  • Loading branch information
mayer79 authored Mar 12, 2024
1 parent dcacb6b commit f7ec18c
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 30 deletions.
24 changes: 17 additions & 7 deletions R/measure_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,15 @@ measure_importance.randomForest <- function(forest, mean_sample = "top_trees", m
if (is.null(forest$forest)) {
stop("Make sure forest has been saved when calling randomForest by randomForest(..., keep.forest = TRUE).")
}
forest_table <-
lapply(1:forest$ntree, function(i) randomForest::getTree(forest, k = i, labelVar = T) %>%
mutate_if(is.factor, as.character) %>%
calculate_tree_depth() %>% cbind(tree = i)) %>% rbindlist()
forest_table <- lapply(
1:forest$ntree,
function(i)
randomForest::getTree(forest, k = i, labelVar = TRUE) %>%
mutate(`split var` = as.character(`split var`)) %>%
calculate_tree_depth() %>%
cbind(tree = i)
) %>%
rbindlist()
min_depth_frame <- dplyr::group_by(forest_table, tree, `split var`) %>%
dplyr::summarize(min(depth))
colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
Expand Down Expand Up @@ -201,9 +206,14 @@ measure_importance.ranger <- function(forest, mean_sample = "top_trees", measure
importance_frame <- data.frame(variable = names(forest$variable.importance), stringsAsFactors = FALSE)
# Get objects necessary to calculate importance measures based on the tree structure
if(any(c("mean_min_depth", "no_of_nodes", "no_of_trees", "times_a_root", "p_value") %in% measures)){
forest_table <-
lapply(1:forest$num.trees, function(i) ranger::treeInfo(forest, tree = i) %>%
calculate_tree_depth_ranger() %>% cbind(tree = i)) %>% rbindlist()
forest_table <- lapply(
1:forest$num.trees,
function(i)
ranger::treeInfo(forest, tree = i) %>%
calculate_tree_depth_ranger() %>%
cbind(tree = i)
) %>%
rbindlist()
min_depth_frame <- dplyr::group_by(forest_table, tree, splitvarName) %>%
dplyr::summarize(min(depth))
colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
Expand Down
24 changes: 17 additions & 7 deletions R/min_depth_distribution.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,15 @@ min_depth_distribution <- function(forest){
#' @export
min_depth_distribution.randomForest <- function(forest){
tree <- NULL; `split var` <- NULL; depth <- NULL
forest_table <-
lapply(1:forest$ntree, function(i) randomForest::getTree(forest, k = i, labelVar = T) %>%
mutate_if(is.factor, as.character) %>%
calculate_tree_depth() %>% cbind(tree = i)) %>% rbindlist()
forest_table <- lapply(
1:forest$ntree,
function(i)
randomForest::getTree(forest, k = i, labelVar = TRUE) %>%
mutate(`split var` = as.character(`split var`)) %>%
calculate_tree_depth() %>%
cbind(tree = i)
) %>%
rbindlist()
min_depth_frame <- dplyr::group_by(forest_table, tree, `split var`) %>%
dplyr::summarize(min(depth))
colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
Expand All @@ -76,9 +81,14 @@ min_depth_distribution.randomForest <- function(forest){
#' @export
min_depth_distribution.ranger <- function(forest){
tree <- NULL; splitvarName <- NULL; depth <- NULL
forest_table <-
lapply(1:forest$num.trees, function(i) ranger::treeInfo(forest, tree = i) %>%
calculate_tree_depth_ranger() %>% cbind(tree = i)) %>% rbindlist()
forest_table <- lapply(
1:forest$num.trees,
function(i)
ranger::treeInfo(forest, tree = i) %>%
calculate_tree_depth_ranger() %>%
cbind(tree = i)
) %>%
rbindlist()
min_depth_frame <- dplyr::group_by(forest_table, tree, splitvarName) %>%
dplyr::summarize(min(depth))
colnames(min_depth_frame) <- c("tree", "variable", "minimal_depth")
Expand Down
52 changes: 36 additions & 16 deletions R/min_depth_interactions.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,16 @@ conditional_depth_ranger <- function(frame, vars){
# randomForest
min_depth_interactions_values <- function(forest, vars){
`.` <- NULL; .SD <- NULL; tree <- NULL; `split var` <- NULL
interactions_frame <-
lapply(1:forest$ntree, function(i) randomForest::getTree(forest, k = i, labelVar = T) %>%
mutate_if(is.factor, as.character) %>%
calculate_tree_depth() %>% cbind(., tree = i, number = 1:nrow(.))) %>%
data.table::rbindlist() %>% as.data.frame()
interactions_frame <- lapply(
1:forest$ntree,
function(i)
randomForest::getTree(forest, k = i, labelVar = TRUE) %>%
mutate(`split var` = as.character(`split var`)) %>%
calculate_tree_depth() %>%
cbind(., tree = i, number = 1:nrow(.))
) %>%
data.table::rbindlist() %>%
as.data.frame()
interactions_frame[vars] <- NA_real_
interactions_frame <-
data.table::as.data.table(interactions_frame)[, conditional_depth(as.data.frame(.SD), vars), by = tree] %>% as.data.frame()
Expand All @@ -89,10 +94,15 @@ min_depth_interactions_values <- function(forest, vars){
# ranger
min_depth_interactions_values_ranger <- function(forest, vars){
`.` <- NULL; .SD <- NULL; tree <- NULL; splitvarName <- NULL
interactions_frame <-
lapply(1:forest$num.trees, function(i) ranger::treeInfo(forest, tree = i) %>%
calculate_tree_depth_ranger() %>% cbind(., tree = i, number = 1:nrow(.))) %>%
data.table::rbindlist() %>% as.data.frame()
interactions_frame <- lapply(
1:forest$num.trees,
function(i)
ranger::treeInfo(forest, tree = i) %>%
calculate_tree_depth_ranger() %>%
cbind(., tree = i, number = 1:nrow(.))
) %>%
data.table::rbindlist() %>%
as.data.frame()
interactions_frame[vars] <- NA_real_
interactions_frame <-
data.table::as.data.table(interactions_frame)[, conditional_depth_ranger(as.data.frame(.SD), vars), by = tree] %>% as.data.frame()
Expand Down Expand Up @@ -185,10 +195,15 @@ min_depth_interactions.randomForest <- function(forest, vars = important_variabl
)
interactions_frame <- merge(interactions_frame, occurrences)
interactions_frame$interaction <- paste(interactions_frame$root_variable, interactions_frame$variable, sep = ":")
forest_table <-
lapply(1:forest$ntree, function(i) randomForest::getTree(forest, k = i, labelVar = T) %>%
mutate_if(is.factor, as.character) %>%
calculate_tree_depth() %>% cbind(tree = i)) %>% rbindlist()
forest_table <- lapply(
1:forest$ntree,
function(i)
randomForest::getTree(forest, k = i, labelVar = TRUE) %>%
mutate(`split var` = as.character(`split var`)) %>%
calculate_tree_depth() %>%
cbind(tree = i)
) %>%
rbindlist()
min_depth_frame <- dplyr::group_by(forest_table, tree, variable = `split var`) %>%
dplyr::summarize(minimal_depth = min(depth))
min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),])
Expand Down Expand Up @@ -238,9 +253,14 @@ min_depth_interactions.ranger <- function(forest, vars = important_variables(mea
occurrences <- tidyr::pivot_longer(occurrences, cols = -"variable", names_to = "root_variable", values_to = "occurrences")
interactions_frame <- merge(interactions_frame, occurrences)
interactions_frame$interaction <- paste(interactions_frame$root_variable, interactions_frame$variable, sep = ":")
forest_table <-
lapply(1:forest$num.trees, function(i) ranger::treeInfo(forest, tree = i) %>%
calculate_tree_depth_ranger() %>% cbind(tree = i)) %>% rbindlist()
forest_table <- lapply(
1:forest$num.trees,
function(i)
ranger::treeInfo(forest, tree = i) %>%
calculate_tree_depth_ranger() %>%
cbind(tree = i)
) %>%
rbindlist()
min_depth_frame <- dplyr::group_by(forest_table, tree, variable = splitvarName) %>%
dplyr::summarize(minimal_depth = min(depth))
min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),])
Expand Down

0 comments on commit f7ec18c

Please sign in to comment.