From 5921d3627c2651d8b55c421c2f0a0af8ae5f942f Mon Sep 17 00:00:00 2001 From: Rawan Mahdi <111605471+rawanmahdi@users.noreply.github.com> Date: Sat, 15 Jul 2023 13:30:51 -0400 Subject: [PATCH] Added keras support (#345) --- python/README.md | 2 +- python/examples/keras_classifier.py | 52 +++++++++++++++++++++++++++++ python/shaprpy/explain.py | 16 +++++++-- 3 files changed, 67 insertions(+), 3 deletions(-) create mode 100644 python/examples/keras_classifier.py diff --git a/python/README.md b/python/README.md index 740fa3a85..6b5b1238a 100644 --- a/python/README.md +++ b/python/README.md @@ -56,7 +56,7 @@ df_shapley, pred_explain, internal, timing = explain( print(df_shapley) ``` -`shaprpy` knows how to explain predictions from models from `sklearn` and `xgboost`. +`shaprpy` knows how to explain predictions from models from `sklearn`, `keras` and `xgboost`. For other models, one can provide a custom `predict_model` function (and optionally a custom `get_model_specs`) to `shaprpy.explain`. See `/examples` for runnable examples, including an example of a custom PyTorch model. diff --git a/python/examples/keras_classifier.py b/python/examples/keras_classifier.py new file mode 100644 index 000000000..73617443e --- /dev/null +++ b/python/examples/keras_classifier.py @@ -0,0 +1,52 @@ +from keras import Sequential +from keras import layers +from keras import utils +from shaprpy import explain +from shaprpy.datasets import load_binary_iris + + +dfx_train, dfx_test, dfy_train, dfy_test = load_binary_iris() + +utils.set_random_seed(1) + +## Build model +model = Sequential([ + layers.Dense(units=8, activation='relu'), + layers.Dense(units=16, activation='relu'), + layers.Dense(units=8, activation='relu'), + layers.Dense(units=1, activation='sigmoid') +]) +model.compile(optimizer="adam", + loss ="binary_crossentropy", + metrics=["accuracy"]) + +## Fit Model +model.fit(dfx_train, dfy_train, + epochs=10, + validation_data=(dfx_test, dfy_test)) +## Shapr +df_shapley, pred_explain, internal, timing = explain( + model = model, + x_train = dfx_train, + x_explain = dfx_test, + approach = 'empirical', + prediction_zero = dfy_train.mean().item(), +) +print(df_shapley) + +""" + none sepal length (cm) sepal width (cm) petal length (cm) \ +1 0.494737 0.042263 0.037911 0.059232 +2 0.494737 0.034217 0.029183 0.045027 +3 0.494737 0.045776 0.031752 0.058278 +4 0.494737 0.014977 0.032691 0.014280 +5 0.494737 0.022742 0.025851 0.027427 + + petal width (cm) +1 0.058412 +2 0.053639 +3 0.070650 +4 0.018697 +5 0.026814 + + """ \ No newline at end of file diff --git a/python/shaprpy/explain.py b/python/shaprpy/explain.py index 1d77cb7d1..0b214a3f5 100644 --- a/python/shaprpy/explain.py +++ b/python/shaprpy/explain.py @@ -45,7 +45,7 @@ def explain( Parameters ---------- model: The model whose predictions we want to explain. - `shaprpy` natively supports `sklearn` and `xgboost` models. + `shaprpy` natively supports `sklearn`, `xgboost` and `keras` models. Unsupported models can still be explained by passing `predict_model` and (optionally) `get_model_specs`. x_explain: Contains the the features, whose predictions ought to be explained. x_train: Contains the data used to estimate the (conditional) distributions for the features @@ -79,7 +79,8 @@ def explain( and a pandas.DataFrame to compute predictions for. The function must give the prediction as a numpy.Array. `None` (the default) uses functions specified internally. Can also be used to override the default function for natively supported model classes. - get_model_specs: An optional function for checking model/data consistency when `model` is not natively supported. + get_model_specs: An optional function for checking model/data consistency when `model` is not natively supported. + This method has yet to be implemented for keras models. The function takes `model` as argument and provides a `dict with 3 elements: - labels: list[str] with the names of each feature. - classes: list[str] with the classes of each features. @@ -305,6 +306,17 @@ def prebuilt_predict_model(model): return lambda m, x: m.predict(xgb.DMatrix(x)) except: pass + + # Look for keras + try: + from keras.models import Model + if isinstance(model, Model): + def predict_fn(m,x): + pred = m.predict(x) + return pred.reshape(pred.shape[0],) + return predict_fn + except: + pass return None