Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache hashes #977

Merged
merged 6 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .lintr
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ linters: with_defaults(
object_name_linter = object_name_linter(c("snake_case", "CamelCase")), # only allow snake case and camel case object names
cyclocomp_linter = NULL, # do not check function complexity
commented_code_linter = NULL, # allow code in comments
line_length_linter = line_length_linter(180)
line_length_linter = line_length_linter(180),
indentation_linter(indent = 2L, hanging_indent_style = "never")
)

2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ Config/testthat/edition: 3
Config/testthat/parallel: false
NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.2.3.9000
Collate:
'mlr_reflections.R'
'BenchmarkResult.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ import(palmerpenguins)
import(paradox)
importFrom(R6,R6Class)
importFrom(R6,is.R6)
importFrom(data.table,as.data.table)
importFrom(data.table,data.table)
importFrom(future,nbrOfWorkers)
importFrom(future,plan)
importFrom(graphics,plot)
Expand Down
1 change: 1 addition & 0 deletions R/DataBackend.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ DataBackend = R6Class("DataBackend", cloneable = FALSE,
}
private$.hash = assert_string(rhs)
},

#' @template field_col_hashes
col_hashes = function() {
cn = setdiff(self$colnames, self$primary_key)
Expand Down
25 changes: 20 additions & 5 deletions R/Resampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@
#' prop.table(table(task$truth(r$train_set(1)))) # roughly same proportion
Resampling = R6Class("Resampling",
public = list(
#' @template field_id
id = NULL,

#' @template field_label
label = NULL,

Expand Down Expand Up @@ -126,7 +123,7 @@ Resampling = R6Class("Resampling",
#'
#' Note that this object is typically constructed via a derived class, e.g. [ResamplingCV] or [ResamplingHoldout].
initialize = function(id, param_set = ps(), duplicated_ids = FALSE, label = NA_character_, man = NA_character_) {
self$id = assert_string(id, min.chars = 1L)
private$.id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$param_set = assert_param_set(param_set)
self$duplicated_ids = assert_flag(duplicated_ids)
Expand Down Expand Up @@ -186,6 +183,7 @@ Resampling = R6Class("Resampling",
instance = private$.combine(lapply(strata$row_id, private$.sample, task = task))
}

private$.hash = NULL
self$instance = instance
self$task_hash = task$hash
self$task_nrow = task$nrow
Expand Down Expand Up @@ -214,6 +212,16 @@ Resampling = R6Class("Resampling",
),

active = list(
#' @template field_id
id = function(rhs) {
  # Active binding: acts as a getter when called without a value and as a
  # setter otherwise.
  if (!missing(rhs)) {
    # Drop the cached hash before storing the new identifier so that it is
    # recomputed on next access.
    private$.hash = NULL
    private$.id = assert_string(rhs, min.chars = 1L)
  } else {
    private$.id
  }
},

#' @field is_instantiated (`logical(1)`)\cr
#' Is `TRUE` if the resampling has been instantiated.
is_instantiated = function(rhs) {
Expand All @@ -227,11 +235,18 @@ Resampling = R6Class("Resampling",
if (!self$is_instantiated) {
return(NA_character_)
}
calculate_hash(list(class(self), self$id, self$param_set$values, self$instance))

if (is.null(private$.hash)) {
private$.hash = calculate_hash(list(class(self), self$id, self$param_set$values, self$instance))
}

private$.hash
}
),

private = list(
.id = NULL,
.hash = NULL,
.groups = NULL,

.get_set = function(getter, i) {
Expand Down
77 changes: 57 additions & 20 deletions R/Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
#' Instead, the methods first create a new [DataBackendDataTable] from the provided new data, and then
#' merge both backends into an abstract [DataBackend] which merges the results on-demand.
#' * `rename()` wraps the [DataBackend] of the Task in an additional [DataBackend] which deals with the renaming. Also updates `$col_roles` and `$col_info`.
#' * `set_levels()` updates the field `col_info()`.
#' * `set_levels()` and `droplevels()` update the field `col_info()`.
#'
#' @template seealso_task
#' @concept Task
Expand All @@ -73,9 +73,6 @@
#' head(task)
Task = R6Class("Task",
public = list(
#' @template field_id
id = NULL,

#' @template field_label
label = NA_character_,

Expand Down Expand Up @@ -114,7 +111,7 @@ Task = R6Class("Task",
#'
#' Note that this object is typically constructed via a derived class, e.g. [TaskClassif] or [TaskRegr].
initialize = function(id, task_type, backend, label = NA_character_, extra_args = list()) {
self$id = assert_string(id, min.chars = 1L)
private$.id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = assert_choice(task_type, mlr_reflections$task_types$type)
if (!inherits(backend, "DataBackend")) {
Expand Down Expand Up @@ -175,7 +172,7 @@ Task = R6Class("Task",
catf("%s (%i x %i)%s", format(self), self$nrow, self$ncol,
if (is.null(self$label) || is.na(self$label)) "" else paste0(": ", self$label))

roles = self$col_roles
roles = private$.col_roles
roles = roles[lengths(roles) > 0L]

# print additional columns as specified in reflections
Expand Down Expand Up @@ -204,7 +201,7 @@ Task = R6Class("Task",
catn(str_indent(sprintf("* %s:", str), roles[[role]]))
})

nrows = list(test = length(self$row_roles$test), holdout = length(self$row_roles$holdout))
nrows = list(test = length(private$.row_roles$test), holdout = length(private$.row_roles$holdout))
if (nrows$test || nrows$holdout) {
str = paste(c(
if(nrows$test) sprintf("%i (test)", nrows$test),
Expand Down Expand Up @@ -237,8 +234,8 @@ Task = R6Class("Task",
assert_choice(data_format, self$data_formats)
assert_flag(ordered)

row_roles = self$row_roles
col_roles = self$col_roles
row_roles = private$.row_roles
col_roles = private$.col_roles

if (is.null(rows)) {
rows = row_roles$use
Expand Down Expand Up @@ -368,6 +365,7 @@ Task = R6Class("Task",
filter = function(rows) {
  # Restrict the rows with role "use" to those also listed in `rows`.
  # Modifies the task in place and returns it invisibly.
  assert_has_backend(self)
  keep = assert_row_ids(rows)

  # The set of used rows is about to change, so the cached hash is dropped.
  private$.hash = NULL

  in_use = private$.row_roles$use
  private$.row_roles$use = intersect(in_use, keep)

  invisible(self)
},
Expand All @@ -387,6 +385,8 @@ Task = R6Class("Task",
assert_has_backend(self)
assert_character(cols)
assert_subset(cols, private$.col_roles$feature)
private$.hash = NULL
private$.col_hashes = NULL
private$.col_roles$feature = intersect(private$.col_roles$feature, cols)
invisible(self)
},
Expand Down Expand Up @@ -452,7 +452,7 @@ Task = R6Class("Task",

# columns with these roles must be present in data
mandatory_roles = c("target", "feature", "weight", "group", "stratum", "order")
mandatory_cols = unlist(self$col_roles[mandatory_roles], use.names = FALSE)
mandatory_cols = unlist(private$.col_roles[mandatory_roles], use.names = FALSE)
missing_cols = setdiff(mandatory_cols, data$colnames)
if (length(missing_cols)) {
stopf("Cannot rbind data to task '%s', missing the following mandatory columns: %s", self$id, str_collapse(missing_cols))
Expand Down Expand Up @@ -484,9 +484,10 @@ Task = R6Class("Task",
tab[, c("type_y", "levels_y") := list(NULL, NULL)]

# everything looks good, modify task
private$.hash = NULL
self$backend = DataBackendRbind$new(self$backend, data)
self$col_info = tab[]
self$row_roles$use = c(self$row_roles$use, data$rownames)
private$.row_roles$use = c(private$.row_roles$use, data$rownames)

invisible(self)
},
Expand Down Expand Up @@ -535,7 +536,10 @@ Task = R6Class("Task",
setkeyv(self$col_info, "id")

# add new features
self$col_roles$feature = union(self$col_roles$feature, setdiff(data$colnames, c(pk, self$col_roles$target)))
private$.hash = NULL
private$.col_hashes = NULL
col_roles = private$.col_roles
private$.col_roles$feature = union(col_roles$feature, setdiff(data$colnames, c(pk, col_roles$target)))

# update backend
self$backend = DataBackendCbind$new(self$backend, data)
Expand All @@ -562,9 +566,11 @@ Task = R6Class("Task",
#' the object in its previous state.
rename = function(old, new) {
assert_has_backend(self)
private$.hash = NULL
private$.col_hashes = NULL
self$backend = DataBackendRename$new(self$backend, old, new)
setkeyv(self$col_info[old, ("id") := new, on = "id"], "id")
self$col_roles = map(self$col_roles, map_values, old = old, new = new)
private$.col_roles = map(private$.col_roles, map_values, old = old, new = new)
invisible(self)
},

Expand Down Expand Up @@ -593,7 +599,10 @@ Task = R6Class("Task",
set_row_roles = function(rows, roles = NULL, add_to = NULL, remove_from = NULL) {
  # Update the row roles for `rows`; the actual role bookkeeping is delegated
  # to task_set_roles(). Returns the task invisibly.
  assert_has_backend(self)
  assert_subset(rows, self$backend$rownames)

  # Reset the cached hash before the roles are replaced.
  private$.hash = NULL
  updated_roles = task_set_roles(private$.row_roles, rows, roles, add_to, remove_from)
  private$.row_roles = updated_roles

  invisible(self)
},

Expand Down Expand Up @@ -622,8 +631,12 @@ Task = R6Class("Task",
set_col_roles = function(cols, roles = NULL, add_to = NULL, remove_from = NULL) {
  # Update the column roles for `cols`. The new role list is produced by
  # task_set_roles() and passed through task_check_col_roles() before it is
  # stored. Returns the task invisibly.
  assert_has_backend(self)
  assert_subset(cols, self$col_info$id)

  # Both cached values are reset before the roles are recomputed.
  private$.hash = NULL
  private$.col_hashes = NULL

  private$.col_roles = task_check_col_roles(
    self,
    task_set_roles(private$.col_roles, cols, roles, add_to, remove_from)
  )

  invisible(self)
},

Expand All @@ -646,6 +659,8 @@ Task = R6Class("Task",

tab = enframe(lapply(levels, unname), name = "id", value = "levels")
tab$fix_factor_levels = TRUE

private$.hash = NULL
self$col_info = ujoin(self$col_info, tab, key = "id")

invisible(self)
Expand All @@ -670,6 +685,7 @@ Task = R6Class("Task",
tab = tab[lengths(levels) > lengths(new_levels)]
tab[, c("levels", "fix_factor_levels") := list(Map(intersect, levels, new_levels), TRUE)]

private$.hash = NULL
self$col_info = ujoin(self$col_info, remove_named(tab, "new_levels"), key = "id")

invisible(self)
Expand Down Expand Up @@ -706,12 +722,27 @@ Task = R6Class("Task",
),

active = list(
#' @template field_id
id = function(rhs) {
  # Active binding: acts as a getter when called without a value and as a
  # setter otherwise.
  if (!missing(rhs)) {
    # Drop the cached hash before storing the new identifier so that it is
    # recomputed on next access.
    private$.hash = NULL
    private$.id = assert_string(rhs, min.chars = 1L)
  } else {
    private$.id
  }
},


#' @template field_hash
hash = function(rhs) {
private$.hash %??% calculate_hash(
class(self), self$id, self$backend$hash, self$col_info,
remove_named(private$.row_roles, "test"), private$.col_roles, private$.properties
)
if (is.null(private$.hash)) {
private$.hash = calculate_hash(
class(self), self$id, self$backend$hash, self$col_info,
remove_named(private$.row_roles, "test"), private$.col_roles, private$.properties
)
}

private$.hash
},

#' @field row_ids (`integer()`)\cr
Expand All @@ -728,7 +759,7 @@ Task = R6Class("Task",
#' * `"row_name"` (`character()`).
row_names = function(rhs) {
assert_ro_binding(rhs)
nn = self$col_roles$name
nn = private$.col_roles$name
if (length(nn) == 0L) {
return(NULL)
}
Expand Down Expand Up @@ -804,6 +835,7 @@ Task = R6Class("Task",
assert_names(names(rhs), "unique", permutation.of = mlr_reflections$task_row_roles, .var.name = "names of row_roles")
rhs = map(rhs, assert_row_ids, .var.name = "elements of row_roles")

private$.hash = NULL
private$.row_roles = rhs
},

Expand Down Expand Up @@ -835,6 +867,8 @@ Task = R6Class("Task",
assert_names(names(rhs), "unique", must.include = mlr_reflections$task_col_roles[[self$task_type]], .var.name = "names of col_roles")
assert_subset(unlist(rhs, use.names = FALSE), setdiff(self$col_info$id, self$backend$primary_key), .var.name = "elements of col_roles")

private$.hash = NULL
private$.col_hashes = NULL
private$.col_roles = task_check_col_roles(self, rhs)
},

Expand Down Expand Up @@ -982,11 +1016,15 @@ Task = R6Class("Task",

#' @template field_col_hashes
col_hashes = function() {
private$.col_hashes %??% self$backend$col_hashes[setdiff(unlist(self$col_roles), self$backend$primary_key)]
if (is.null(private$.col_hashes)) {
private$.col_hashes = self$backend$col_hashes[setdiff(unlist(private$.col_roles), self$backend$primary_key)]
}
private$.col_hashes
}
),

private = list(
.id = NULL,
.properties = NULL,
.col_roles = NULL,
.row_roles = NULL,
Expand All @@ -995,7 +1033,6 @@ Task = R6Class("Task",

deep_clone = function(name, value) {
# NB: DataBackends are never copied!
# TODO: check if we can assume col_info to be read-only
if (name == "col_info") copy(value) else value
}
)
Expand Down
8 changes: 4 additions & 4 deletions man/Resampling.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions man/Task.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading