# Architecture of DALEX {#architecture}
DALEX's architecture is simple and consistent. There are only three rules to remember when using this tool.
* First - use the `explain()` function to enrich a black-box model with the additional metadata required by explainers. Different explainers require different metadata. The list is given in Section \@ref(explainFunction).
* Second - use an explainer function that calculates the required descriptions. The individual explainers are introduced in Chapters \@ref(modelUnderstanding) and \@ref(predictionUnderstanding).
* Third - use the generic `print()` or `plot()` function to inspect the result. Both functions work for one or more models.
These three steps are presented in Figure 2.1.
<p><span class="marginnote">Figure 2.1. The overview of DALEX's architecture. <br/><br/>
*A)* Any predictive model with defined input $x$ and output $y_{raw} \in \mathcal R$ may be used. <br/><br/>
*B)* Models are first enriched with additional metadata, such as a function that calculates predictions, validation data, model label or other components. The `explain()` function creates an object belonging to the `explainer` class that is used in further processing.<br/><br/>
*C)* Specialized explainers calculate numerical summaries that can be plotted with the generic `plot()` function.
</span>
<img src="images/architecture.png"/></p>
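The three rules can be sketched as a short session. This is a minimal sketch rather than a recommended analysis: the model formula and the names `model_demo`, `explainer_demo` and `performance_demo` are illustrative, the training set `apartments` is reused as validation data for brevity, and `model_performance()` stands in as one example of an explainer function.

```r
library("DALEX")

# Step 1: wrap a fitted model together with validation data and true labels.
# Assumption: the training set `apartments` doubles as validation data here.
model_demo <- lm(m2.price ~ surface + floor, data = apartments)
explainer_demo <- explain(model_demo,
                          data = apartments[, c("surface", "floor")],
                          y = apartments$m2.price)

# Step 2: run an explainer function on the wrapped model.
performance_demo <- model_performance(explainer_demo)

# Step 3: inspect the result with the generic functions.
print(performance_demo)
plot(performance_demo)
```

The same pattern applies to the other explainers: only the function called in the second step changes.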
## The `explain()` function {#explainFunction}
DALEX is designed to work with various black-box models such as tree ensembles, linear models and neural networks. Unfortunately, R packages that create such models are very inconsistent: different tools use different interfaces to train, validate and use models. Two of the most popular frameworks for machine learning are `mlr` [@mlr] and `caret` [@caret], and dozens of other R packages may be used for modeling.
This is why, as the first step, DALEX wraps the black-box model with metadata that unifies the model interface.
Below is a list of arguments accepted by the `explain()` function.
```
explain(model, data, y, predict_function,
        link, ..., label)
```
* `model` - an R object, a model to be explained. *Required by*: all explainers.
* `data` - a `data.frame` or `matrix`, a dataset that will be used for model validation. It should have the same structure as the dataset used for training. *Required by*: model performance, variable importance. *Default*: extracted from the `model` object if possible.
* `y` - a numeric vector with true labels paired with observations in `data`. *Required by*: variable importance. *Default*: no default.
* `predict_function` - a function that takes two arguments, model and data, and returns a numeric vector with predictions. Predictions should be calculated on the same scale as the `y` labels. *Required by*: all explainers. *Default*: the generic `predict()` function.
* `link` - a transformation/link function that is applied to model predictions. *Required by*: variable effect. *Default*: the identity `I()` function.
* `label` - a character string, the name of the model that will be used in plots. *Required by*: plots. *Default*: extracted from the `class` attribute of the `model`.
<p><span class="marginnote">Figure 2.2. The `explain()` function embeds the `model`, validation `data` and `y` labels in a container. The model is accessed via a universal interface specified by `predict_function()` and `link_function()`. The `label` field contains a unique name of the model.
</span>
<img src="images/explain_scheme.png"/></p>
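The requirement that `predict_function` return predictions on the same scale as `y` matters most when the default `predict()` method works on a different scale. The sketch below uses only base R: a logistic regression on the built-in `mtcars` data serves as a stand-in classifier, and `predict_prob` is a hypothetical name for a function with the two-argument form (model, data) described above.

```r
# Toy binary classifier fitted with base R; `mtcars` is only a stand-in dataset.
model_glm <- glm(am ~ wt, data = mtcars, family = binomial)

# The default predict() for a glm returns log-odds, not the scale of 0/1 labels.
log_odds <- predict(model_glm, newdata = mtcars)

# Hypothetical predict_function: takes model and data, returns probabilities.
predict_prob <- function(model, data) {
  as.numeric(predict(model, newdata = data, type = "response"))
}

probabilities <- predict_prob(model_glm, mtcars)
range(probabilities)   # probabilities lie within [0, 1]

# The inverse logit maps the log-odds onto the probability scale.
all.equal(as.numeric(plogis(log_odds)), probabilities)
```

A function of this shape could then be supplied as the `predict_function` argument of `explain()`, so that predictions and 0/1 labels are comparable.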
The next section introduces a regression use case that shows how to use the `explain()` function and for what purposes. The same functions may be used for binary classification.
## Use case: Regression. Apartment prices in Warsaw {#useCaseApartmetns}
To illustrate applications of DALEX to regression problems, we will use the artificial dataset `apartments` available in the `DALEX` package. Our goal is to predict the price per square meter of an apartment based on selected features: construction year, surface, floor, number of rooms and district. Four of these variables are numeric, while the fifth, district, is categorical. Prices are given in euros.
```{r eval=FALSE}
library("DALEX")
head(apartments)
```
```{r hr_data, echo=FALSE}
library("DALEX")
knitr::kable(
  head(apartments),
  caption = 'Artificial dataset about apartment prices in Warsaw. The goal here is to predict the price per square meter for a new apartment.'
)
```
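The variable types can be checked directly before modeling. This is a small sketch that assumes the column names used in the model formulas of this chapter; `numeric_columns` is an illustrative helper name.

```r
library("DALEX")

# Class of every column in the dataset.
sapply(apartments, class)

# The four numeric predictors used in this chapter.
numeric_columns <- c("construction.year", "surface", "floor", "no.rooms")
sapply(apartments[, numeric_columns], is.numeric)

# The categorical predictor and its distinct values.
unique(apartments$district)
```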
### Model 1: Linear regression
The first model is based on linear regression. It will be a simple model without any feature engineering.
```{r}
apartments_lm_model <- lm(m2.price ~ construction.year + surface + floor +
                            no.rooms + district, data = apartments)
summary(apartments_lm_model)
```
A second dataset, `apartmentsTest`, can be used to validate the model. Below is the root mean square error (RMSE) calculated on the validation data.
```{r}
predicted_mi2_lm <- predict(apartments_lm_model, apartmentsTest)
sqrt(mean((predicted_mi2_lm - apartmentsTest$m2.price)^2))
```
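The expression above takes the square root of the mean squared error, so the result is expressed in the units of the response rather than in squared units. A minimal sketch with toy numbers (not taken from the apartments data) shows the difference; the helper name `rmse` is illustrative.

```r
# Illustrative helper: root mean square error of a vector of predictions.
rmse <- function(actual, predicted) {
  sqrt(mean((predicted - actual)^2))
}

# Toy vectors, not taken from the apartments data.
actual    <- c(10, 20, 30, 40)
predicted <- c(12, 18, 33, 40)

# Mean squared error: (4 + 4 + 9 + 0) / 4 = 4.25, in squared units.
mse_value <- mean((predicted - actual)^2)

# Root mean square error: sqrt(4.25), in the units of the response.
rmse_value <- rmse(actual, predicted)
```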
To create an explainer for the regression model, it is enough to call the `explain()` function with the `model`, `data` and `y` arguments. The next chapter shows how to use this explainer.
```{r}
explainer_lm <- explain(apartments_lm_model,
                        data = apartmentsTest[,2:6], y = apartmentsTest$m2.price)
```
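To see what the container holds, the fields of an explainer can be inspected directly. The sketch below is self-contained and rests on a few assumptions: the names `model_demo` and `explainer_demo` are illustrative, `apartments` is reused as validation data to keep the example short, and the fields `model`, `data`, `predict_function` and `label` are assumed to be stored under the same names as the corresponding arguments of `explain()`.

```r
library("DALEX")

predictors <- c("construction.year", "surface", "floor", "no.rooms", "district")
model_demo <- lm(m2.price ~ ., data = apartments[, c("m2.price", predictors)])

# Assumption: `apartments` is reused as validation data to keep the sketch short.
explainer_demo <- explain(model_demo,
                          data  = apartments[, predictors],
                          y     = apartments$m2.price,
                          label = "linear regression")

class(explainer_demo)      # includes "explainer"
explainer_demo$label       # the label passed above
dim(explainer_demo$data)   # validation data stored in the container

# The stored predict function can be called directly on the stored model.
head(explainer_demo$predict_function(explainer_demo$model, explainer_demo$data))
```

Setting `label` explicitly is optional, but it makes it easier to tell models apart when several of them appear in one plot.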
### Model 2: Random forest
The second model is based on a random forest. It is a very flexible out-of-the-box model.
```{r}
library("randomForest")
set.seed(59)
apartments_rf_model <- randomForest(m2.price ~ construction.year + surface + floor +
                                      no.rooms + district, data = apartments)
apartments_rf_model
```
Below is the root mean square error (RMSE) calculated on the `apartmentsTest` dataset.
```{r}
predicted_mi2_rf <- predict(apartments_rf_model, apartmentsTest)
sqrt(mean((predicted_mi2_rf - apartmentsTest$m2.price)^2))
```
We also create an explainer for the random forest model. The next chapter shows how to use this explainer.
```{r}
explainer_rf <- explain(apartments_rf_model,
                        data = apartmentsTest[,2:6], y = apartmentsTest$m2.price)
```
**These two models have identical performance on the validation data!** Which one should be used?