Skip to content

Commit

Permalink
Added keras support (#345)
Browse files Browse the repository at this point in the history
  • Loading branch information
rawanmahdi authored Jul 15, 2023
1 parent 111053f commit 5921d36
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ df_shapley, pred_explain, internal, timing = explain(
print(df_shapley)
```

`shaprpy` knows how to explain predictions from models from `sklearn` and `xgboost`.
`shaprpy` knows how to explain predictions from models from `sklearn`, `keras` and `xgboost`.
For other models, one can provide a custom `predict_model` function (and optionally a custom `get_model_specs`) to `shaprpy.explain`.

See `/examples` for runnable examples, including an example of a custom PyTorch model.
52 changes: 52 additions & 0 deletions python/examples/keras_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from keras import Sequential
from keras import layers
from keras import utils
from shaprpy import explain
from shaprpy.datasets import load_binary_iris


dfx_train, dfx_test, dfy_train, dfy_test = load_binary_iris()

utils.set_random_seed(1)

## Build model
model = Sequential([
layers.Dense(units=8, activation='relu'),
layers.Dense(units=16, activation='relu'),
layers.Dense(units=8, activation='relu'),
layers.Dense(units=1, activation='sigmoid')
])
model.compile(optimizer="adam",
loss ="binary_crossentropy",
metrics=["accuracy"])

## Fit Model
model.fit(dfx_train, dfy_train,
epochs=10,
validation_data=(dfx_test, dfy_test))
## Shapr
df_shapley, pred_explain, internal, timing = explain(
model = model,
x_train = dfx_train,
x_explain = dfx_test,
approach = 'empirical',
prediction_zero = dfy_train.mean().item(),
)
print(df_shapley)

"""
none sepal length (cm) sepal width (cm) petal length (cm) \
1 0.494737 0.042263 0.037911 0.059232
2 0.494737 0.034217 0.029183 0.045027
3 0.494737 0.045776 0.031752 0.058278
4 0.494737 0.014977 0.032691 0.014280
5 0.494737 0.022742 0.025851 0.027427
petal width (cm)
1 0.058412
2 0.053639
3 0.070650
4 0.018697
5 0.026814
"""
16 changes: 14 additions & 2 deletions python/shaprpy/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def explain(
Parameters
----------
model: The model whose predictions we want to explain.
`shaprpy` natively supports `sklearn` and `xgboost` models.
`shaprpy` natively supports `sklearn`, `xgboost` and `keras` models.
Unsupported models can still be explained by passing `predict_model` and (optionally) `get_model_specs`.
x_explain: Contains the the features, whose predictions ought to be explained.
x_train: Contains the data used to estimate the (conditional) distributions for the features
Expand Down Expand Up @@ -79,7 +79,8 @@ def explain(
and a pandas.DataFrame to compute predictions for. The function must give the prediction as a numpy.Array.
`None` (the default) uses functions specified internally.
Can also be used to override the default function for natively supported model classes.
get_model_specs: An optional function for checking model/data consistency when `model` is not natively supported.
get_model_specs: An optional function for checking model/data consistency when `model` is not natively supported.
This method has yet to be implemented for keras models.
The function takes `model` as argument and provides a `dict with 3 elements:
- labels: list[str] with the names of each feature.
- classes: list[str] with the classes of each features.
Expand Down Expand Up @@ -305,6 +306,17 @@ def prebuilt_predict_model(model):
return lambda m, x: m.predict(xgb.DMatrix(x))
except:
pass

# Look for keras
try:
from keras.models import Model
if isinstance(model, Model):
def predict_fn(m,x):
pred = m.predict(x)
return pred.reshape(pred.shape[0],)
return predict_fn
except:
pass

return None

Expand Down

0 comments on commit 5921d36

Please sign in to comment.