Skip to content

Commit

Permalink
Refine permutation workaround
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic committed Jan 14, 2025
1 parent 06fac92 commit 6010004
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions baybe/insights/shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,19 @@ def explain(self, data: pd.DataFrame | None = None, /) -> shap.Explanation:
else:
explanations = self.explainer(df_aligned)

# Permute explanation object data according to input column order.
# Do not do this for the base_values as it can be a scalar.
# Permute explanation object data according to input column order
# (`base_values` can be a scalar or vector)
# TODO: https://github.com/shap/shap/issues/3958
idx = self.background_data.columns.get_indexer(data.columns)
for attr in ["values", "data"]:
setattr(explanations, attr, getattr(explanations, attr)[:, idx])
for attr in ["values", "data", "base_values"]:
try:
setattr(explanations, attr, getattr(explanations, attr)[:, idx])
except IndexError as ex:
if not (
isinstance(explanations.base_values, float)
or explanations.base_values.shape[1] == 1
):
raise TypeError("Unexpected explanation format.") from ex
explanations.feature_names = [explanations.feature_names[i] for i in idx]

# Reduce dimensionality of explanations to 2D in case
Expand Down

0 comments on commit 6010004

Please sign in to comment.