diff --git a/R/modelStudio.R b/R/modelStudio.R
index 86e02b8..699e260 100644
--- a/R/modelStudio.R
+++ b/R/modelStudio.R
@@ -60,7 +60,7 @@
 #' library("DALEX")
 #' library("modelStudio")
 #'
-#' #:# ex1 classification on 'titanic_imputed' dataset
+#' #:# ex1 classification on 'titanic' data
 #'
 #' # fit a model
 #' model_titanic <- glm(survived ~., data = titanic_imputed, family = "binomial")
@@ -82,7 +82,7 @@
 #'
 #' \donttest{
 #'
-#' #:# ex2 regression on 'apartments' dataset
+#' #:# ex2 regression on 'apartments' data
 #' library("ranger")
 #'
 #' model_apartments <- ranger(m2.price ~. ,data = apartments)
@@ -113,13 +113,11 @@
 #' #:# ex3 xgboost model on 'HR' dataset
 #' library("xgboost")
 #'
-#' # fit a model
 #' HR_matrix <- model.matrix(status == "fired" ~ . -1, HR)
 #'
+#' # fit a model
 #' xgb_matrix <- xgb.DMatrix(HR_matrix, label = HR$status == "fired")
-#'
-#' params <- list(max_depth = 7, objective = "binary:logistic", eval_metric = "auc")
-#'
+#' params <- list(max_depth = 3, objective = "binary:logistic", eval_metric = "auc")
 #' model_HR <- xgb.train(params, xgb_matrix, nrounds = 300)
 #'
 #' # create an explainer for the model
@@ -206,7 +204,7 @@ modelStudio.explainer <- function(explainer,
   if (show_info) {
     pb <- progress_bar$new(
       format = " Calculating :what \n Elapsed time: :elapsedfull ETA::eta", # :percent [:bar]
-      total = (3*B + 2 + 1)*obs_count + (B + 3*B + B) + 1,
+      total = (3*B + 2 + 1)*obs_count + (2*B + 3*B + B) + 1,
       show_after = 0
     )
     pb$tick(0, tokens = list(what = "..."))
@@ -216,7 +214,7 @@ modelStudio.explainer <- function(explainer,
   fi <- calculate(
     ingredients::feature_importance(
       model, data, y, predict_function, variables = variable_names, B = B, N = 10*N),
-    "ingredients::feature_importance", show_info, pb, B)
+    "ingredients::feature_importance", show_info, pb, 2*B)
 
   which_numerical <- which_variables_are_numeric(data)
 
diff --git a/README.md b/README.md
index 7043c66..e9421dd 100644
--- a/README.md
+++ b/README.md
@@ -130,7 +130,7 @@ test_matrix <- model.matrix(survived ~.-1, test)
 
 # fit a model
 xgb_matrix <- xgb.DMatrix(train_matrix, label = train$survived)
-params <- list(max_depth = 7, objective = "binary:logistic", eval_metric = "auc")
+params <- list(max_depth = 3, objective = "binary:logistic", eval_metric = "auc")
 model <- xgb.train(params, xgb_matrix, nrounds = 500)
 
 # create an explainer for the model
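For context on the `R/modelStudio.R` hunks above: `ingredients::feature_importance()` is now run with `N = 10*N` rows and its progress is ticked with weight `2*B` instead of `B`, so the bar's `total` has to grow by `B` or the ticks would overshoot it. A minimal sketch of the tick arithmetic, assuming only the `progress` package; the `B` and `obs_count` values here are illustrative:

```r
library("progress")

B <- 10         # number of permutation rounds (illustrative)
obs_count <- 3  # number of locally explained observations (illustrative)

# per-observation ticks, unchanged by the patch
local_ticks <- (3*B + 2 + 1) * obs_count

# global ticks: feature importance now contributes 2*B, apparently to
# match its heavier N = 10*N computation; the remaining 3*B + B cover
# the other global explanations, plus one final tick
global_ticks <- (2*B + 3*B + B) + 1

pb <- progress_bar$new(
  format = " Calculating :what \n Elapsed time: :elapsedfull ETA::eta",
  total = local_ticks + global_ticks,
  show_after = 0
)
pb$tick(0, tokens = list(what = "..."))
```

With these values the budget is 99 + 61 = 160 ticks where the old formula gave 150, absorbing the extra `B` ticks that `feature_importance()` now emits.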
-
+
- - + + diff --git a/pkgdown/favicon/demo.html b/pkgdown/favicon/demo.html index 893b78c..97d1e55 100644 --- a/pkgdown/favicon/demo.html +++ b/pkgdown/favicon/demo.html @@ -1787,9 +1787,9 @@
-
+
- - + + diff --git a/pkgdown/favicon/h2o.html b/pkgdown/favicon/h2o.html index 396a5d2..aba3d5d 100644 --- a/pkgdown/favicon/h2o.html +++ b/pkgdown/favicon/h2o.html @@ -1787,9 +1787,9 @@
-
+
- - + + diff --git a/pkgdown/favicon/keras.html b/pkgdown/favicon/keras.html index 540c194..159ee4a 100644 --- a/pkgdown/favicon/keras.html +++ b/pkgdown/favicon/keras.html @@ -1787,9 +1787,9 @@
-
+
- - + + diff --git a/pkgdown/favicon/lightgbm.html b/pkgdown/favicon/lightgbm.html index 8d2ceee..f7baf64 100644 --- a/pkgdown/favicon/lightgbm.html +++ b/pkgdown/favicon/lightgbm.html @@ -1787,9 +1787,9 @@
-
+
- - + + diff --git a/pkgdown/favicon/mlr.html b/pkgdown/favicon/mlr.html index 68304c1..54ac378 100644 --- a/pkgdown/favicon/mlr.html +++ b/pkgdown/favicon/mlr.html @@ -1787,9 +1787,9 @@
-
+
- - + + diff --git a/pkgdown/favicon/mlr3.html b/pkgdown/favicon/mlr3.html index 09c03c9..9731952 100644 --- a/pkgdown/favicon/mlr3.html +++ b/pkgdown/favicon/mlr3.html @@ -1787,9 +1787,9 @@
-
+
- - + + diff --git a/pkgdown/favicon/parsnip.html b/pkgdown/favicon/parsnip.html index b7ef3bc..1c1dec7 100644 --- a/pkgdown/favicon/parsnip.html +++ b/pkgdown/favicon/parsnip.html @@ -1787,9 +1787,9 @@
-
+
- - + + diff --git a/pkgdown/favicon/scikitlearn.html b/pkgdown/favicon/scikitlearn.html index 7f6dbcc..249f48a 100644 --- a/pkgdown/favicon/scikitlearn.html +++ b/pkgdown/favicon/scikitlearn.html @@ -1787,9 +1787,9 @@
-
+
- - + + diff --git a/pkgdown/favicon/xgboost.html b/pkgdown/favicon/xgboost.html index 35b1c35..13abb75 100644 --- a/pkgdown/favicon/xgboost.html +++ b/pkgdown/favicon/xgboost.html @@ -1787,9 +1787,9 @@
-
+
diff --git a/vignettes/ms-perks-features.Rmd b/vignettes/ms-perks-features.Rmd
index bcdc0db..e433598 100644
--- a/vignettes/ms-perks-features.Rmd
+++ b/vignettes/ms-perks-features.Rmd
@@ -27,7 +27,7 @@ Let's use `HR` dataset to explore `modelStudio` parameters:
 ```{r results="hide"}
 train <- DALEX::HR
-train$fired <- ifelse(train$status == "fired", 1, 0)
+train$fired <- as.factor(ifelse(train$status == "fired", 1, 0))
 train$status <- NULL
 
 head(train)
 
@@ -42,7 +42,7 @@ Prepare `HR_test` data and a `ranger` model for the explainer:
 ```{r results="hide", eval = FALSE}
 # fit a ranger model
 library("ranger")
-model <- ranger(fired ~., data = train)
+model <- ranger(fired ~., data = train, probability = TRUE)
 
 # prepare validation dataset
 test <- DALEX::HR_test[1:1000,]
@@ -189,7 +189,7 @@ modelStudio(explainer,
 
 --------------------------------------------------------------------
 
-## plot options
+## additional options
 
 Customize some of `modelStudio` looks by overwriting default options
 returned by the `ms_options()` function. Full list of options:
@@ -212,8 +212,44 @@ modelStudio(explainer,
           options = new_options)
 ```
 
+All visual options can be changed after the calculations using `ms_update_options()`.
+
+```{r eval = FALSE}
+old_ms <- modelStudio(explainer)
+old_ms
+
+# update the options
+new_ms <- ms_update_options(old_ms,
+                            time = 0,
+                            facet_dim = c(1,2),
+                            margin_left = 150)
+new_ms
+```
+
 -------------------------------------------------------------------
 
+## update observations
+
+Use `ms_update_observations()` to add more observations with their local explanations to the `modelStudio`.
+
+```{r eval = FALSE}
+old_ms <- modelStudio(explainer)
+old_ms
+
+# add new observations
+plus_ms <- ms_update_observations(old_ms,
+                                  explainer,
+                                  new_observation = test[101:102,])
+plus_ms
+
+# overwrite old observations
+new_ms <- ms_update_observations(old_ms,
+                                 explainer,
+                                 new_observation = test[103:104,],
+                                 overwrite = TRUE)
+new_ms
+```
+
 ## DALEXtra
 
 Use `explain_*()` functions from the [DALEXtra](https://github.com/ModelOriented/DALEXtra)
@@ -225,10 +261,8 @@ library(DALEXtra)
 library(mlr)
 
 # fit a model
-task <- makeRegrTask(id = "task", data = train, target = "fired")
-
-learner <- makeLearner("regr.ranger", predict.type = "response")
-
+task <- makeClassifTask(id = "task", data = train, target = "fired")
+learner <- makeLearner("classif.ranger", predict.type = "prob")
 model <- train(learner, task)
 
 # create an explainer for the model
@@ -238,8 +272,7 @@ explainer_mlr <- explain_mlr(model,
                              label = "mlr")
 
 # make a studio for the model
-modelStudio(explainer_mlr,
-            B = 10)
+modelStudio(explainer_mlr)
 ```
 
 ## References
diff --git a/vignettes/ms-r-python-examples.Rmd b/vignettes/ms-r-python-examples.Rmd
index 44c439d..d4d857e 100644
--- a/vignettes/ms-r-python-examples.Rmd
+++ b/vignettes/ms-r-python-examples.Rmd
@@ -95,9 +95,7 @@ train$survived <- as.factor(train$survived)
 
 # fit a model
 task <- TaskClassif$new(id = "titanic", backend = train, target = "survived")
-
 learner <- lrn("classif.ranger", predict_type = "prob")
-
 learner$train(task)
 
 # create an explainer for the model
@@ -137,9 +135,7 @@ test_matrix <- model.matrix(survived ~.-1, test)
 
 # fit a model
 xgb_matrix <- xgb.DMatrix(train_matrix, label = train$survived)
-
-params <- list(max_depth = 7, objective = "binary:logistic", eval_metric = "auc")
-
+params <- list(max_depth = 3, objective = "binary:logistic", eval_metric = "auc")
 model <- xgb.train(params, xgb_matrix, nrounds = 500)
 
 # create an explainer for the model
@@ -179,9 +175,8 @@ test <- data[-index,]
 train$survived <- as.factor(train$survived)
 
 # fit a model
-cv <- trainControl(method = "repeatedcv", number = 3, repeats = 10)
-
-model <- train(survived ~ ., data = train, method = "gbm", trControl = cv)
+cv <- trainControl(method = "repeatedcv", number = 3, repeats = 3)
+model <- train(survived ~ ., data = train, method = "gbm", trControl = cv, verbose = FALSE)
 
 # create an explainer for the model
 explainer <- explain(model,
@@ -212,6 +207,7 @@ data <- DALEX::titanic_imputed
 
 # init h2o
 h2o.init()
+h2o.no_progress()
 
 # split the data
 h2o_split <- h2o.splitFrame(as.h2o(data))
@@ -223,12 +219,8 @@ train$survived <- as.factor(train$survived)
 
 # fit a model
 automl <- h2o.automl(y = "survived", training_frame = train, max_runtime_secs = 30)
-
 model <- automl@leader
 
-# stop h2o progress printing
-h2o.no_progress()
-
 # create an explainer for the model
 explainer <- explain_h2o(model,
                          data = test,
@@ -266,7 +258,7 @@ train <- data[index,]
 test <- data[-index,]
 
 # fit a model
-rand_forest() %>%
+model <- rand_forest() %>%
   set_engine("ranger", importance = "impurity") %>%
   set_mode("regression") %>%
   fit(m2.price ~ ., data = train)
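The last hunk fixes a silent bug in the parsnip example: the piped `rand_forest()` chain was never assigned, so no `model` object existed for the explainer step. A minimal sketch of the corrected flow, assuming `parsnip`, `ranger`, and `DALEX` are installed (recent DALEX versions ship a `yhat` method for parsnip `model_fit` objects, so plain `explain()` should work):

```r
library("parsnip")
library("DALEX")

train <- DALEX::apartments
test  <- DALEX::apartments_test

# fit a model; the pipe's result must be assigned, otherwise the
# fitted model_fit object is printed once and discarded
model <- rand_forest() %>%
  set_engine("ranger", importance = "impurity") %>%
  set_mode("regression") %>%
  fit(m2.price ~ ., data = train)

# create an explainer for the model
explainer <- explain(model,
                     data = test,
                     y = test$m2.price,
                     label = "parsnip")
```

The h2o change is the same kind of fix: `h2o.no_progress()` now runs right after `h2o.init()`, silencing progress output for every subsequent call (`as.h2o()`, `h2o.splitFrame()`, `h2o.automl()`) instead of only for calls made after training.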