Skip to content

Commit

Permalink
add validation split option
Browse files Browse the repository at this point in the history
  • Loading branch information
jvpoulos committed Oct 14, 2022
1 parent a330f53 commit 88e72df
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion R/Lrnr_gru_keras.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#' be applied at given stages of the training procedure. Default callback
#' function \code{callback_early_stopping} stops training if the validation
#' loss does not improve across \code{patience} number of epochs.
#' - \code{validation_split}: Fraction of the training data to be used as validation data. Default is 0 (no validation).
#' - \code{...}: Other parameters passed to \code{\link[keras]{keras}}.
#'
#' @examples
Expand Down Expand Up @@ -74,7 +75,7 @@
#' valid_task <- validation(task, fold = task$folds[[1]])
#'
#' # instantiate learner, then fit and predict (simplified example)
#' gru_lrnr <- Lrnr_gru_keras$new(batch_size = 1, epochs = 200)
#' gru_lrnr <- Lrnr_gru_keras$new(batch_size = 1, epochs = 200, validation_split=0.2)
#' gru_fit <- gru_lrnr$train(train_task)
#' gru_preds <- gru_fit$predict(valid_task)
#' }
Expand All @@ -95,6 +96,7 @@ Lrnr_gru_keras <- R6Class(
callbacks = list(
keras::callback_early_stopping(patience = 10)
),
validation_split=0,
...) {
params <- args_to_list()
super$initialize(params = params, ...)
Expand Down Expand Up @@ -187,6 +189,7 @@ Lrnr_gru_keras <- R6Class(
batch_size = args$batch_size,
epochs = args$epochs,
callbacks = args$callbacks,
validation_split= args$validation_split,
verbose = verbose,
shuffle = FALSE
)
Expand Down
5 changes: 4 additions & 1 deletion R/Lrnr_lstm_keras.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#' be applied at given stages of the training procedure. Default callback
#' function \code{callback_early_stopping} stops training if the validation
#' loss does not improve across \code{patience} number of epochs.
#' - \code{validation_split}: Fraction of the training data to be used as validation data. Default is 0 (no validation).
#' - \code{...}: Other parameters passed to \code{\link[keras]{keras}}.
#'
#' @examples
Expand Down Expand Up @@ -72,7 +73,7 @@
#' valid_task <- validation(task, fold = task$folds[[1]])
#'
#' # instantiate learner, then fit and predict (simplified example)
#' lstm_lrnr <- Lrnr_lstm_keras$new(batch_size = 1, epochs = 200)
#' lstm_lrnr <- Lrnr_lstm_keras$new(batch_size = 1, epochs = 200, validation_split=0.2)
#' lstm_fit <- lstm_lrnr$train(train_task)
#' lstm_preds <- lstm_fit$predict(valid_task)
#' }
Expand All @@ -93,6 +94,7 @@ Lrnr_lstm_keras <- R6Class(
lr = 0.001,
layers = 1,
callbacks = list(keras::callback_early_stopping(patience = 10)),
validation_split=0,
...) {
params <- args_to_list()
super$initialize(params = params, ...)
Expand Down Expand Up @@ -185,6 +187,7 @@ Lrnr_lstm_keras <- R6Class(
batch_size = args$batch_size,
epochs = args$epochs,
callbacks = args$callbacks,
validation_split= args$validation_split,
verbose = verbose,
shuffle = FALSE
)
Expand Down

0 comments on commit 88e72df

Please sign in to comment.