Showing 4 changed files with 157 additions and 74 deletions.
exceptions.py

@@ -0,0 +1,3 @@
+class ExplainableAIError(Exception):
+    """Base exception class for ExplainableAI package"""
+    pass
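
Downstream code can now catch a single package-level exception type instead of bare Exceptions. A minimal caller-side sketch (illustrative only; the absolute import path assumes the package installs as explainableai, and run_interpretation is a hypothetical helper):

```python
from explainableai.exceptions import ExplainableAIError  # assumed install path

try:
    results = run_interpretation()  # hypothetical helper wrapping the package calls
except ExplainableAIError as err:
    # Every failure inside the package surfaces as this one type
    print(f"Interpretation failed: {err}")
```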
logging_config.py

@@ -0,0 +1,22 @@
+import logging
+
+def setup_logging():
+    logger = logging.getLogger('explainableai')
+    logger.setLevel(logging.DEBUG)
+
+    # Create console handler and set level to debug
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.DEBUG)
+
+    # Create formatter
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+
+    # Add formatter to ch
+    ch.setFormatter(formatter)
+
+    # Add ch to logger
+    logger.addHandler(ch)
+
+    return logger
+
+logger = setup_logging()
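
Other modules pick up this configuration by importing the module-level logger, as model_interpretability.py does below. A quick sketch of how a consumer might reuse or quiet it (the explainableai import path is an assumption):

```python
import logging

from explainableai.logging_config import logger  # assumed install path

logger.info("interpretation started")  # emitted through the console handler above

# Callers who find DEBUG output too chatty can raise the threshold at runtime
logging.getLogger('explainableai').setLevel(logging.WARNING)
```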
model_interpretability.py

@@ -1,82 +1,116 @@
 # model_interpretability.py
 import shap
 import lime
 import lime.lime_tabular
 import matplotlib.pyplot as plt
 import numpy as np
 import logging
+import pandas as pd
+import io
+import base64
+from .logging_config import logger
+from .exceptions import ExplainableAIError
 
-logger=logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
 
-def calculate_shap_values(model, X):
-    logger.debug("Calculating values...")
+def calculate_shap_values(model, X, feature_names):
+    logger.debug("Calculating SHAP values...")
     try:
-        explainer = shap.Explainer(model, X)
-        shap_values = explainer(X)
-        logger.info("Values caluated...")
+        X_df = pd.DataFrame(X, columns=feature_names)
+        if hasattr(model, 'predict_proba'):
+            explainer = shap.TreeExplainer(model)
+            shap_values = explainer.shap_values(X_df)
+        else:
+            explainer = shap.Explainer(model, X_df)
+            shap_values = explainer(X_df)
+        logger.info("SHAP values calculated successfully.")
         return shap_values
     except Exception as e:
-        logger.error(f"Some error occurred in calculating values...{str(e)}")
+        logger.error(f"Error in calculate_shap_values: {str(e)}")
+        raise ExplainableAIError(f"Error in calculate_shap_values: {str(e)}")
 
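The rewritten calculate_shap_values routes anything with predict_proba to shap.TreeExplainer, so classifiers are expected to be tree-based. A minimal sketch of calling it (illustrative; the dataset, model choice, and import path are assumptions):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

from explainableai.model_interpretability import calculate_shap_values  # assumed path

# Toy data: 200 rows, 4 features, binary target
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = (X[:, 0] + X[:, 1] > 0).astype(int)
feature_names = ['f0', 'f1', 'f2', 'f3']

# A tree-based classifier, since the predict_proba branch builds a TreeExplainer
model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)
shap_values = calculate_shap_values(model, X, feature_names)
```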
-def plot_shap_summary(shap_values, X):
-    logger.debug("Summary...")
+def plot_shap_summary(shap_values, X, feature_names):
+    logger.debug("Plotting SHAP summary...")
     try:
-        plt.figure(figsize=(10, 8))
-        shap.summary_plot(shap_values, X, plot_type="bar", show=False)
+        plt.figure(figsize=(12, 8))
+        shap.summary_plot(shap_values, X, plot_type="bar", feature_names=feature_names, show=False)
         plt.tight_layout()
+
+        # Save plot to file
         plt.savefig('shap_summary.png')
+        logger.info("SHAP summary plot saved as 'shap_summary.png'")
+
+        # Convert plot to base64 for display
+        img = io.BytesIO()
+        plt.savefig(img, format='png')
+        img.seek(0)
+        plot_url = base64.b64encode(img.getvalue()).decode()
         plt.close()
-    except TypeError as e:
-        logger.error(f"Error in generating SHAP summary plot: {str(e)}")
-        logger.error("Attempting alternative SHAP visualization...")
-        try:
-            plt.figure(figsize=(10, 8))
-            shap.summary_plot(shap_values.values, X.values, feature_names=X.columns.tolist(), plot_type="bar", show=False)
-            plt.tight_layout()
-            plt.savefig('shap_summary.png')
-            plt.close()
-        except Exception as e2:
-            logger.error(f"Alternative SHAP visualization also failed: {str(e2)}")
-            logger.error("Skipping SHAP summary plot.")
+
+        return plot_url
+    except Exception as e:
+        logger.error(f"Error in plot_shap_summary: {str(e)}")
+        raise ExplainableAIError(f"Error in plot_shap_summary: {str(e)}")
 
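plot_shap_summary now returns the figure as a base64-encoded PNG in addition to writing shap_summary.png, which makes it easy to inline in an HTML report. A short sketch of consuming that return value (the report file name is illustrative):

```python
plot_url = plot_shap_summary(shap_values, X, feature_names)

# Embed the encoded PNG directly as a data URI
html = f'<img src="data:image/png;base64,{plot_url}" alt="SHAP summary"/>'
with open('report.html', 'w') as f:
    f.write(html)
```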
 def get_lime_explanation(model, X, instance, feature_names):
-    logger.debug("Explaining model...")
+    logger.debug("Generating LIME explanation...")
     try:
         explainer = lime.lime_tabular.LimeTabularExplainer(
             X,
             feature_names=feature_names,
             class_names=['Negative', 'Positive'],
-            mode='classification'
+            mode='classification' if hasattr(model, 'predict_proba') else 'regression'
         )
-        exp = explainer.explain_instance(instance, model.predict_proba)
-        logger.info("Model explained...")
+        exp = explainer.explain_instance(
+            instance,
+            model.predict_proba if hasattr(model, 'predict_proba') else model.predict
+        )
+        logger.info("LIME explanation generated successfully.")
         return exp
     except Exception as e:
-        logger.error(f"Some error occurred in explaining model...{str(e)}")
+        logger.error(f"Error in get_lime_explanation: {str(e)}")
+        raise ExplainableAIError(f"Error in get_lime_explanation: {str(e)}")
 
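get_lime_explanation now picks classification or regression mode from the model itself. A minimal sketch of explaining one row, reusing the toy model, X, and feature_names from the earlier snippet (illustrative):

```python
# instance must be a single 1-D row with the same columns as X
instance = X[0]
exp = get_lime_explanation(model, X, instance, feature_names)

# (feature condition, weight) pairs for the explained prediction
print(exp.as_list())
```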
 def plot_lime_explanation(exp):
-    exp.as_pyplot_figure()
-    plt.tight_layout()
-    plt.savefig('lime_explanation.png')
-    plt.close()
+    logger.debug("Plotting LIME explanation...")
+    try:
+        plt.figure(figsize=(12, 8))
+        exp.as_pyplot_figure()
+        plt.tight_layout()
+
+        # Save plot to file
+        plt.savefig('lime_explanation.png')
+        logger.info("LIME explanation plot saved as 'lime_explanation.png'")
+
+        # Convert plot to base64 for display
+        img = io.BytesIO()
+        plt.savefig(img, format='png')
+        img.seek(0)
+        plot_url = base64.b64encode(img.getvalue()).decode()
+        plt.close()
+
+        return plot_url
+    except Exception as e:
+        logger.error(f"Error in plot_lime_explanation: {str(e)}")
+        raise ExplainableAIError(f"Error in plot_lime_explanation: {str(e)}")
 
-def plot_ice_curve(model, X, feature, num_ice_lines=50):
-    ice_data = X.copy()
-    feature_values = np.linspace(X[feature].min(), X[feature].max(), num=100)
-
-    plt.figure(figsize=(10, 6))
-    for _ in range(num_ice_lines):
-        ice_instance = ice_data.sample(n=1, replace=True)
-        predictions = []
-        for value in feature_values:
-            ice_instance[feature] = value
-            predictions.append(model.predict_proba(ice_instance)[0][1])
-        plt.plot(feature_values, predictions, color='blue', alpha=0.1)
-
-    plt.xlabel(feature)
-    plt.ylabel('Predicted Probability')
-    plt.title(f'ICE Plot for {feature}')
-    plt.tight_layout()
-    plt.savefig(f'ice_plot_{feature}.png')
-    plt.close()
+def interpret_model(model, X, feature_names, instance_index=0):
+    logger.info("Starting model interpretation...")
+    try:
+        # SHAP analysis
+        shap_values = calculate_shap_values(model, X, feature_names)
+        shap_plot_url = plot_shap_summary(shap_values, X, feature_names)
+
+        # LIME analysis
+        instance = X[instance_index]
+        lime_exp = get_lime_explanation(model, X, instance, feature_names)
+        lime_plot_url = plot_lime_explanation(lime_exp)
+
+        interpretation_results = {
+            "shap_values": shap_values,
+            "shap_plot_url": shap_plot_url,
+            "lime_explanation": lime_exp,
+            "lime_plot_url": lime_plot_url
+        }
+
+        logger.info("Model interpretation completed successfully.")
+        return interpretation_results
+    except Exception as e:
+        logger.error(f"Error in interpret_model: {str(e)}")
+        raise ExplainableAIError(f"Error in interpret_model: {str(e)}")
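
interpret_model ties the SHAP and LIME steps together into one call (plot_ice_curve is dropped in this commit). An end-to-end sketch, reusing the toy model from the earlier snippets (illustrative; key names follow the dict built above):

```python
results = interpret_model(model, X, feature_names, instance_index=0)

print(results['lime_explanation'].as_list())   # LIME weights for row 0
print(len(results['shap_plot_url']))           # length of the base64 PNG string
# shap_summary.png and lime_explanation.png are also written to the working directory
```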