Skip to content

Commit

Permalink
perf: use %chin%
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Dec 3, 2024
1 parent b4c4360 commit 843ff75
Show file tree
Hide file tree
Showing 22 changed files with 72 additions and 72 deletions.
2 changes: 1 addition & 1 deletion R/DataBackendRename.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ DataBackendRename = R6Class("DataBackendRename", inherit = DataBackend, cloneabl
old = old[ii]
new = new[ii]

if (self$primary_key %in% old) {
if (self$primary_key %chin% old) {
stopf("Renaming the primary key is not supported")
}

Expand Down
4 changes: 2 additions & 2 deletions R/HotstartStack.R
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ calculate_cost = function(start_learner, learner, hotstart_id) {
cost = learner$param_set$values[[hotstart_id]] - start_learner$param_set$values[[hotstart_id]]
if (cost == 0) return(-1)

if ("hotstart_backward" %in% learner$properties && "hotstart_forward" %in% learner$properties) {
if ("hotstart_backward" %chin% learner$properties && "hotstart_forward" %chin% learner$properties) {
if (cost < 0) 0 else cost
} else if ("hotstart_backward" %in% learner$properties) {
} else if ("hotstart_backward" %chin% learner$properties) {
if (cost < 0) 0 else NA_real_
} else {
if (cost > 0) cost else NA_real_
Expand Down
8 changes: 4 additions & 4 deletions R/LearnerClassifDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
pv = self$param_set$get_values(tags = "train")
pv$count_marshaling = pv$count_marshaling %??% FALSE
roll = function(name) {
name %in% names(pv) && pv[[name]] > runif(1L)
name %chin% names(pv) && pv[[name]] > runif(1L)
}

if (!is.null(pv$sleep_train)) {
Expand Down Expand Up @@ -248,7 +248,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
n = task$nrow
pv = self$param_set$get_values(tags = "predict")
roll = function(name) {
name %in% names(pv) && pv[[name]] > runif(1L)
name %chin% names(pv) && pv[[name]] > runif(1L)
}

if (!is.null(pv$sleep_predict)) {
Expand Down Expand Up @@ -281,7 +281,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
response = prob = NULL
missing_type = pv$predict_missing_type %??% "na"

if ("response" %in% self$predict_type) {
if ("response" %chin% self$predict_type) {
response = rep.int(unclass(model$response), n)
if (!is.null(pv$predict_missing)) {
ii = sample.int(n, n * pv$predict_missing)
Expand All @@ -292,7 +292,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
}
}

if ("prob" %in% self$predict_type) {
if ("prob" %chin% self$predict_type) {
cl = task$class_names
prob = matrix(runif(n * length(cl)), nrow = n)
prob = prob / rowSums(prob)
Expand Down
6 changes: 3 additions & 3 deletions R/LearnerClassifRpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
if ("weights" %in% task$properties) {
if ("weights" %chin% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}

Expand All @@ -89,11 +89,11 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
newdata = task$data(cols = task$feature_names)
response = prob = NULL

if ("response" %in% self$predict_type) {
if ("response" %chin% self$predict_type) {
response = invoke(predict, self$model, newdata = newdata, type = "class",
.opts = allow_partial_matching, .args = pv)
response = unname(response)
} else if ("prob" %in% self$predict_type) {
} else if ("prob" %chin% self$predict_type) {
prob = invoke(predict, self$model, newdata = newdata, type = "prob",
.opts = allow_partial_matching, .args = pv)
rownames(prob) = NULL
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerRegrRpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr,
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
if ("weights" %in% task$properties) {
if ("weights" %chin% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}

Expand Down
10 changes: 5 additions & 5 deletions R/Measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -200,19 +200,19 @@ Measure = R6Class("Measure",
# check should be added to assert_measure()
# except when the checks are superfluous for rr$score() and bmr$score()
# these checks should be added bellow
if ("requires_task" %in% self$properties && is.null(task)) {
if ("requires_task" %chin% self$properties && is.null(task)) {
stopf("Measure '%s' requires a task", self$id)
}

if ("requires_learner" %in% self$properties && is.null(learner)) {
if ("requires_learner" %chin% self$properties && is.null(learner)) {
stopf("Measure '%s' requires a learner", self$id)
}

if (!is_scalar_na(self$task_type) && self$task_type != prediction$task_type) {
stopf("Measure '%s' incompatible with task type '%s'", self$id, prediction$task_type)
}

if ("requires_train_set" %in% self$properties && is.null(train_set)) {
if ("requires_train_set" %chin% self$properties && is.null(train_set)) {
stopf("Measure '%s' requires the train_set", self$id)
}

Expand Down Expand Up @@ -258,7 +258,7 @@ Measure = R6Class("Measure",
#' @template field_predict_sets
predict_sets = function(rhs) {
if (!missing(rhs)) {
private$.predict_sets = assert_subset(rhs, mlr_reflections$predict_sets, empty.ok = "requires_no_prediction" %in% self$properties)
private$.predict_sets = assert_subset(rhs, mlr_reflections$predict_sets, empty.ok = "requires_no_prediction" %chin% self$properties)
}
private$.predict_sets
},
Expand Down Expand Up @@ -385,7 +385,7 @@ score_single_measure = function(measure, task, learner, train_set, prediction) {
#' @noRd
score_measures = function(obj, measures, reassemble = TRUE, view = NULL, iters = NULL) {
reassemble_learners = reassemble ||
some(measures, function(m) any(c("requires_learner", "requires_model") %in% m$properties))
some(measures, function(m) any(c("requires_learner", "requires_model") %chin% m$properties))
tab = get_private(obj)$.data$as_data_table(view = view, reassemble_learners = reassemble_learners, convert_predictions = FALSE)

if (!is.null(iters)) {
Expand Down
2 changes: 1 addition & 1 deletion R/PredictionClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ PredictionClassif = R6Class("PredictionClassif", inherit = Prediction,
as.data.table.PredictionClassif = function(x, ...) { # nolint
tab = as.data.table(x$data[c("row_ids", "truth", "response")])

if ("prob" %in% x$predict_types) {
if ("prob" %chin% x$predict_types) {
prob = as.data.table(x$data$prob)
setnames(prob, names(prob), paste0("prob.", names(prob)))
tab = rcbind(tab, prob)
Expand Down
4 changes: 2 additions & 2 deletions R/PredictionDataClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,11 @@ create_empty_prediction_data.TaskClassif = function(task, learner) {
truth = factor(character(), levels = cn)
)

if ("response" %in% predict_types) {
if ("response" %chin% predict_types) {
pdata$response = pdata$truth
}

if ("prob" %in% predict_types) {
if ("prob" %chin% predict_types) {
pdata$prob = matrix(numeric(), nrow = 0L, ncol = length(cn), dimnames = list(NULL, cn))
}

Expand Down
8 changes: 4 additions & 4 deletions R/PredictionDataRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ c.PredictionDataRegr = function(..., keep_duplicates = TRUE) { # nolint
result = as.list(tab)
result$quantiles = quantiles

if ("distr" %in% predict_types[[1L]]) {
if ("distr" %chin% predict_types[[1L]]) {
require_namespaces("distr6", msg = "To predict probability distributions, please install %s")
result$distr = do.call(c, map(dots, "distr"))
}
Expand Down Expand Up @@ -137,15 +137,15 @@ create_empty_prediction_data.TaskRegr = function(task, learner) {
truth = numeric()
)

if ("response" %in% predict_types) {
if ("response" %chin% predict_types) {
pdata$response = pdata$truth
}

if ("se" %in% predict_types) {
if ("se" %chin% predict_types) {
pdata$se = pdata$truth
}

if ("distr" %in% predict_types) {
if ("distr" %chin% predict_types) {
pdata$distr = list()
}

Expand Down
8 changes: 4 additions & 4 deletions R/PredictionRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
self$data = pdata
predict_types = intersect(names(mlr_reflections$learner_predict_types[["regr"]]), names(pdata))
# response is in saved in quantiles matrix
if ("quantiles" %in% predict_types) predict_types = union(predict_types, "response")
if ("quantiles" %chin% predict_types) predict_types = union(predict_types, "response")
self$predict_types = predict_types
if (is.null(pdata$response)) private$.quantile_response = attr(quantiles, "response")
}
Expand Down Expand Up @@ -94,7 +94,7 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
#' Access the stored vector distribution.
#' Requires package `distr6`(in repository \url{https://raphaels1.r-universe.dev}) .
distr = function() {
if ("distr" %in% self$predict_types) {
if ("distr" %chin% self$predict_types) {
require_namespaces("distr6", msg = "To predict probability distributions, please install %s")
}
return(self$data$distr)
Expand All @@ -111,12 +111,12 @@ PredictionRegr = R6Class("PredictionRegr", inherit = Prediction,
as.data.table.PredictionRegr = function(x, ...) { # nolint
tab = as.data.table(x$data[c("row_ids", "truth", "response", "se")])

if ("quantiles" %in% x$predict_types) {
if ("quantiles" %chin% x$predict_types) {
tab = rcbind(tab, as.data.table(x$data$quantiles))
set(tab, j = "response", value = x$response)
}

if ("distr" %in% x$predict_types) {
if ("distr" %chin% x$predict_types) {
require_namespaces("distr6", msg = "To predict probability distributions, please install %s")
tab$distr = list(x$distr)
}
Expand Down
18 changes: 9 additions & 9 deletions R/Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ Task = R6Class("Task",

# print additional columns as specified in reflections
before = mlr_reflections$task_print_col_roles$before
iwalk(before[before %in% names(roles)], function(role, str) {
iwalk(before[before %chin% names(roles)], function(role, str) {
catn(str_indent(sprintf("* %s:", str), roles[[role]]))
})

Expand All @@ -246,7 +246,7 @@ Task = R6Class("Task",

# print additional columns are specified in reflections
after = mlr_reflections$task_print_col_roles$after
iwalk(after[after %in% names(roles)], function(role, str) {
iwalk(after[after %chin% names(roles)], function(role, str) {
catn(str_indent(sprintf("* %s:", str), roles[[role]]))
})

Expand Down Expand Up @@ -371,7 +371,7 @@ Task = R6Class("Task",
levels = function(cols = NULL) {
if (is.null(cols)) {
cols = unlist(private$.col_roles[c("target", "feature")], use.names = FALSE)
cols = self$col_info[get("id") %in% cols & get("type") %in% c("factor", "ordered"), "id", with = FALSE][[1L]]
cols = self$col_info[get("id") %chin% cols & get("type") %chin% c("factor", "ordered"), "id", with = FALSE][[1L]]
} else {
assert_subset(cols, self$col_info$id)
}
Expand Down Expand Up @@ -469,7 +469,7 @@ Task = R6Class("Task",
type_check = TRUE

if (is.data.frame(data)) {
pk_in_backend = pk %in% names(data)
pk_in_backend = pk %chin% names(data)
type_check = FALSE # done by auto-converter

keep_cols = intersect(names(data), self$col_info$id)
Expand Down Expand Up @@ -521,7 +521,7 @@ Task = R6Class("Task",
}

# merge factor levels
ii = tab[type %in% c("factor", "ordered"), which = TRUE]
ii = tab[type %chin% c("factor", "ordered"), which = TRUE]
for (i in ii) {
x = tab[["levels"]][[i]]
y = tab[["levels_y"]][[i]]
Expand Down Expand Up @@ -730,7 +730,7 @@ Task = R6Class("Task",
#' @return Modified `self`.
droplevels = function(cols = NULL) {
assert_has_backend(self)
tab = self$col_info[get("type") %in% c("factor", "ordered"), c("id", "levels", "fix_factor_levels"), with = FALSE]
tab = self$col_info[get("type") %chin% c("factor", "ordered"), c("id", "levels", "fix_factor_levels"), with = FALSE]
if (!is.null(cols)) {
tab = tab[list(cols), on = "id", nomatch = NULL]
}
Expand Down Expand Up @@ -930,7 +930,7 @@ Task = R6Class("Task",

assert_has_backend(self)
assert_list(rhs, .var.name = "row_roles")
if ("test" %in% names(rhs) || "holdout" %in% names(rhs)) {
if ("test" %chin% names(rhs) || "holdout" %chin% names(rhs)) {
stopf("Setting row roles 'test'/'holdout' is no longer possible.")
}
assert_names(names(rhs), "unique", permutation.of = mlr_reflections$task_row_roles, .var.name = "names of row_roles")
Expand Down Expand Up @@ -1337,7 +1337,7 @@ col_info = function(x, ...) {
#' @export
col_info.data.table = function(x, primary_key = character(), ...) { # nolint
types = map_chr(x, function(x) class(x)[1L])
discrete = setdiff(names(types)[types %in% c("factor", "ordered")], primary_key)
discrete = setdiff(names(types)[types %chin% c("factor", "ordered")], primary_key)
levels = insert_named(named_list(names(types)), lapply(x[, discrete, with = FALSE], distinct_values, drop = FALSE))
data.table(id = names(types), type = unname(types), levels = levels, key = "id")
}
Expand All @@ -1346,7 +1346,7 @@ col_info.data.table = function(x, primary_key = character(), ...) { # nolint
#' @export
col_info.DataBackend = function(x, ...) { # nolint
types = map_chr(x$head(1L), function(x) class(x)[1L])
discrete = setdiff(names(types)[types %in% c("factor", "ordered")], x$primary_key)
discrete = setdiff(names(types)[types %chin% c("factor", "ordered")], x$primary_key)
levels = insert_named(named_list(names(types)), x$distinct(rows = NULL, cols = discrete))
data.table(id = names(types), type = unname(types), levels = levels, key = "id")
}
Expand Down
6 changes: 3 additions & 3 deletions R/assertions.R
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ assert_predictable = function(task, learner) {
stopf("Learner '%s' has received tasks with different columns in train and predict.", learner$id)
}

ids = train_task$col_info[get("id") %in% cols_train, "id"]$id
ids = train_task$col_info[get("id") %chin% cols_train, "id"]$id
ci_predict = task$col_info[list(ids), c("id", "type", "levels"), on = "id"]
ci_train = train_task$col_info[list(ids), c("id", "type", "levels"), on = "id"]

Expand Down Expand Up @@ -260,11 +260,11 @@ assert_measure = function(measure, task = NULL, learner = NULL, prediction = NUL
#' @param prediction ([Prediction]).
#' @rdname mlr_assertions
assert_scorable = function(measure, task, learner, prediction = NULL, .var.name = vname(measure)) {
if ("requires_model" %in% measure$properties && is.null(learner$model)) {
if ("requires_model" %chin% measure$properties && is.null(learner$model)) {
stopf("Measure '%s' requires the trained model", measure$id)
}

if ("requires_model" %in% measure$properties && is_marshaled_model(learner$model)) {
if ("requires_model" %chin% measure$properties && is_marshaled_model(learner$model)) {
stopf("Measure '%s' requires the trained model, but model is in marshaled form", measure$id)
}

Expand Down
6 changes: 3 additions & 3 deletions R/benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps

setDT(design)
task = learner = resampling = NULL
if ("task" %in% clone) {
if ("task" %chin% clone) {
design[, "task" := list(list(task[[1L]]$clone())), by = list(hashes(task))]
}
if ("learner" %in% clone) {
if ("learner" %chin% clone) {
design[, "learner" := list(list(learner[[1L]]$clone())), by = list(hashes(learner))]
}
if ("resampling" %in% clone) {
if ("resampling" %chin% clone) {
design[, "resampling" := list(list(resampling[[1L]]$clone())), by = list(hashes(resampling))]
}

Expand Down
6 changes: 3 additions & 3 deletions R/resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@
#' print(bmr1$combine(bmr2))
resample = function(task, learner, resampling, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), unmarshal = TRUE) {
assert_subset(clone, c("task", "learner", "resampling"))
task = assert_task(as_task(task, clone = "task" %in% clone))
learner = assert_learner(as_learner(learner, clone = "learner" %in% clone, discard_state = TRUE))
resampling = assert_resampling(as_resampling(resampling, clone = "resampling" %in% clone))
task = assert_task(as_task(task, clone = "task" %chin% clone))
learner = assert_learner(as_learner(learner, clone = "learner" %chin% clone, discard_state = TRUE))
resampling = assert_resampling(as_resampling(resampling, clone = "resampling" %chin% clone))
assert_flag(store_models)
assert_flag(store_backends)
# this does not check the internal validation task as it might not be set yet
Expand Down
2 changes: 1 addition & 1 deletion R/set_validate.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ set_validate = function(learner, validate, ...) {

#' @export
set_validate.Learner = function(learner, validate, ...) {
if (!"validation" %in% learner$properties) {
if (!"validation" %chin% learner$properties) {
stopf("Learner '%s' does not support validation.", learner$id)
}
learner$validate = validate
Expand Down
4 changes: 2 additions & 2 deletions R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL,
if (!is.null(pb)) {
pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration))
}
if ("internal_valid" %in% learner$predict_sets && is.null(task$internal_valid_task) && is.null(get0("validate", learner))) {
if ("internal_valid" %chin% learner$predict_sets && is.null(task$internal_valid_task) && is.null(get0("validate", learner))) {
stopf("Cannot set the predict_type field of learner '%s' to 'internal_valid' if there is no internal validation task configured", learner$id)
}

Expand Down Expand Up @@ -351,7 +351,7 @@ prediction_tasks_and_sets = function(task, train_result, validate, sets, predict
return(list(tasks = tasks[predict_sets], sets = sets[predict_sets]))
}

if ("internal_valid" %in% predict_sets) {
if ("internal_valid" %chin% predict_sets) {
if (is.numeric(validate) || identical(validate, "test")) {
# in this scenario, the internal_valid_task was created during learner_train, which means that it used the
# primary task. The selected ids are returned via the train result
Expand Down
Loading

0 comments on commit 843ff75

Please sign in to comment.