exception handling
ombhojane committed Oct 12, 2024
1 parent 7f4cffe commit 88b66e8
Showing 4 changed files with 157 additions and 74 deletions.
62 changes: 43 additions & 19 deletions explainableai/core.py
@@ -3,6 +3,8 @@
import colorama
from colorama import Fore, Style

from explainableai.exceptions import ExplainableAIError

# Initialize colorama
colorama.init(autoreset=True)

@@ -27,6 +29,8 @@
from reportlab.platypus import PageBreak
import logging
from sklearn.model_selection import cross_val_score
from .model_interpretability import interpret_model
from .logging_config import logger


logger=logging.getLogger(__name__)
@@ -136,12 +140,12 @@ def _preprocess_data(self):
except Exception as e:
logger.error(f"Some error occur while updating...{str(e)}")

def analyze(self, batch_size=None, parallel=False):

def analyze(self, batch_size=None, parallel=False, instance_index=0):
logger.debug("Analysing...")
results = {}

logger.info("Evaluating model performance...")
# Evaluate model performance (batch processing if batch_size is provided)
if batch_size:
results['model_performance'] = self._process_in_batches(self._evaluate_model_in_batches, batch_size, parallel)
else:
@@ -153,37 +157,41 @@ def analyze(self, batch_size=None, parallel=False):

logger.info("Generating visualizations...")
self._generate_visualizations(self.feature_importance)

# Calculate SHAP values (batch processing if batch_size is provided)

logger.info("Calculating SHAP values...")
if batch_size:
results['shap_values'] = self._process_in_batches(self._calculate_shap_in_batches, batch_size, parallel)
shap_values = self._process_in_batches(self._calculate_shap_in_batches, batch_size, parallel)
results['shap_values'] = shap_values
else:
results['shap_values'] = calculate_shap_values(self.model, self.X, self.feature_names)

# Perform cross-validation (batch processing if batch_size is provided)
logger.info("Performing cross-validation...")
if batch_size:
results['cv_scores'] = self._process_in_batches(self._cross_validate_in_batches, batch_size, parallel)
cv_results = self._process_in_batches(self._cross_validate_in_batches, batch_size, parallel)
results['cv_scores'] = (np.mean(cv_results['mean_score']), np.mean(cv_results['std_score']))
else:
mean_score, std_score = cross_validate(self.model, self.X, self.y)
results['cv_scores'] = (mean_score, std_score)

logger.info("Model comparison results:")
results['model_comparison'] = self.model_comparison_results

logger.info("Performing model interpretation (SHAP and LIME)...")
try:
interpretation_results = interpret_model(self.model, self.X, self.feature_names, instance_index)
results.update(interpretation_results)
except ExplainableAIError as e:
logger.warning(f"Model interpretation failed: {str(e)}")
results['interpretation_error'] = str(e)

self._print_results(results)

logger.info("Generating LLM explanation...")
results['llm_explanation'] = get_llm_explanation(self.gemini_model, results)

# Generate XAI report after analysis
logger.info("Generating XAI report")
self.generate_report()

self.results = results
return results

def _process_in_batches(self, batch_func, batch_size, parallel=False):
results = []
num_batches = (len(self.X) + batch_size - 1) // batch_size # Calculate number of batches
@@ -209,7 +217,7 @@ def _process_in_batches(self, batch_func, batch_size, parallel=False):

# Aggregate results after batch processing
return self._aggregate_results(results)

# private helper functions
def _evaluate_model_in_batches(self, X_batch, y_batch):
return evaluate_model(self.model, X_batch, y_batch, self.is_classifier)
@@ -249,7 +257,6 @@ def _aggregate_results(self, results):

return aggregated_result


def generate_report(self, filename='xai_report.pdf'):
if self.results is None:
raise ValueError("No analysis results available. Please run analyze() first.")
@@ -272,9 +279,17 @@ def generate_report(self, filename='xai_report.pdf'):
for section, section_func in sections.items():
if input(f"Do you want {section} in xai_report? (y/n) ").lower() in ['y', 'yes']:
section_func(report)
self._generate_shap_lime_visualizations(report)

report.generate()

def _generate_shap_lime_visualizations(self, report):
report.add_heading("SHAP and LIME Visualizations", level=2)
report.add_image('shap_summary.png')
report.content.append(PageBreak())
report.add_image('lime_explanation.png')
report.content.append(PageBreak())

def _generate_model_comparison(self, report):
report.add_heading("Model Comparison", level=2)
model_comparison_data = [["Model", "CV Score", "Test Score"]] + [
@@ -286,7 +301,12 @@ def _generate_model_performance(self, report):
def _generate_model_performance(self, report):
report.add_heading("Model Performance", level=2)
for metric, value in self.results['model_performance'].items():
report.add_paragraph(f"**{metric}:** {value:.4f}" if isinstance(value, (int, float, np.float64)) else f"**{metric}:**\n{value}")
if isinstance(value, np.ndarray):
report.add_paragraph(f"**{metric}:**\n{value}")
elif isinstance(value, (int, float, np.float64)):
report.add_paragraph(f"**{metric}:** {value:.4f}")
else:
report.add_paragraph(f"**{metric}:** {value}")

def _generate_feature_importance(self, report):
report.add_heading("Feature Importance", level=2)
@@ -401,10 +421,14 @@ def _print_results(self, results):
logger.info("- ROC Curve: roc_curve.png")
logger.info("- Precision-Recall Curve: precision_recall_curve.png")

if results['shap_values'] is not None:
logger.info("\nSHAP values calculated successfully. See 'shap_summary.png' for visualization.")
else:
logger.info("\nSHAP values calculation failed. Please check the console output for more details.")
if 'shap_plot_url' in results:
logger.info("\nSHAP summary plot saved as 'shap_summary.png'")
logger.info("SHAP plot URL (base64 encoded) available in results['shap_plot_url']")

if 'lime_plot_url' in results:
logger.info("\nLIME explanation plot saved as 'lime_explanation.png'")
logger.info("LIME plot URL (base64 encoded) available in results['lime_plot_url']")

except Exception as e:
logger.error(f"Error occur in printing results...{str(e)}")

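For orientation, a minimal sketch of calling the updated analyze() API. Only the analyze(batch_size, parallel, instance_index) signature and the result keys shown above come from this commit; the XAIWrapper class name, its import path, and its fit() call are assumptions for illustration.

import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

from explainableai import XAIWrapper  # assumed import path, not shown in this diff

X, y = make_classification(n_samples=200, n_features=5, random_state=0)
X = pd.DataFrame(X, columns=[f"f{i}" for i in range(5)])

xai = XAIWrapper()
xai.fit(RandomForestClassifier(), X, y)  # assumed fit() signature
# instance_index selects the row explained by the new SHAP/LIME interpretation step
results = xai.analyze(batch_size=50, parallel=False, instance_index=0)
print(results.get("llm_explanation"))
print(results.get("interpretation_error"))  # present only if interpretation failed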
3 changes: 3 additions & 0 deletions explainableai/exceptions.py
@@ -0,0 +1,3 @@
class ExplainableAIError(Exception):
"""Base exception class for ExplainableAI package"""
pass
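The catch-and-continue pattern that analyze() now applies to this exception can be expressed as a small helper; the safe_interpret wrapper below is illustrative only and not part of the package.

from explainableai.exceptions import ExplainableAIError

def safe_interpret(step, *args, **kwargs):
    # Run one interpretation step and return (result, error_message) instead of
    # raising, mirroring how analyze() downgrades ExplainableAIError to a warning.
    try:
        return step(*args, **kwargs), None
    except ExplainableAIError as e:
        return None, str(e)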
22 changes: 22 additions & 0 deletions explainableai/logging_config.py
@@ -0,0 +1,22 @@
import logging

def setup_logging():
logger = logging.getLogger('explainableai')
logger.setLevel(logging.DEBUG)

# Create console handler and set level to debug
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)

# Create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Add formatter to ch
ch.setFormatter(formatter)

# Add ch to logger
logger.addHandler(ch)

return logger

logger = setup_logging()
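Usage note: inside the package this shared logger is imported as "from .logging_config import logger" (see the core.py and model_interpretability.py changes). A short external sketch, assuming the package is installed as explainableai:

import logging
from explainableai.logging_config import logger

logger.info("explainableai logging configured")            # routed through the shared console handler
logging.getLogger("explainableai").setLevel(logging.INFO)  # optionally raise the level to quiet DEBUG output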
144 changes: 89 additions & 55 deletions explainableai/model_interpretability.py
@@ -1,82 +1,116 @@
# model_interpretability.py
import shap
import lime
import lime.lime_tabular
import matplotlib.pyplot as plt
import numpy as np
import logging
import pandas as pd
import io
import base64
from .logging_config import logger
from .exceptions import ExplainableAIError

logger=logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

def calculate_shap_values(model, X):
logger.debug("Calculating values...")
def calculate_shap_values(model, X, feature_names):
logger.debug("Calculating SHAP values...")
try:
explainer = shap.Explainer(model, X)
shap_values = explainer(X)
logger.info("Values caluated...")
X_df = pd.DataFrame(X, columns=feature_names)
if hasattr(model, 'predict_proba'):
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_df)
else:
explainer = shap.Explainer(model, X_df)
shap_values = explainer(X_df)
logger.info("SHAP values calculated successfully.")
return shap_values
except Exception as e:
logger.error(f"Some error occurred in calculating values...{str(e)}")
logger.error(f"Error in calculate_shap_values: {str(e)}")
raise ExplainableAIError(f"Error in calculate_shap_values: {str(e)}")

def plot_shap_summary(shap_values, X):
logger.debug("Summary...")
def plot_shap_summary(shap_values, X, feature_names):
logger.debug("Plotting SHAP summary...")
try:
plt.figure(figsize=(10, 8))
shap.summary_plot(shap_values, X, plot_type="bar", show=False)
plt.figure(figsize=(12, 8))
shap.summary_plot(shap_values, X, plot_type="bar", feature_names=feature_names, show=False)
plt.tight_layout()

# Save plot to file
plt.savefig('shap_summary.png')
logger.info("SHAP summary plot saved as 'shap_summary.png'")

# Convert plot to base64 for display
img = io.BytesIO()
plt.savefig(img, format='png')
img.seek(0)
plot_url = base64.b64encode(img.getvalue()).decode()
plt.close()
except TypeError as e:
logger.error(f"Error in generating SHAP summary plot: {str(e)}")
logger.error("Attempting alternative SHAP visualization...")
try:
plt.figure(figsize=(10, 8))
shap.summary_plot(shap_values.values, X.values, feature_names=X.columns.tolist(), plot_type="bar", show=False)
plt.tight_layout()
plt.savefig('shap_summary.png')
plt.close()
except Exception as e2:
logger.error(f"Alternative SHAP visualization also failed: {str(e2)}")
logger.error("Skipping SHAP summary plot.")

return plot_url
except Exception as e:
logger.error(f"Error in plot_shap_summary: {str(e)}")
raise ExplainableAIError(f"Error in plot_shap_summary: {str(e)}")

def get_lime_explanation(model, X, instance, feature_names):
logger.debug("Explaining model...")
logger.debug("Generating LIME explanation...")
try:
explainer = lime.lime_tabular.LimeTabularExplainer(
X,
feature_names=feature_names,
class_names=['Negative', 'Positive'],
mode='classification'
mode='classification' if hasattr(model, 'predict_proba') else 'regression'
)
exp = explainer.explain_instance(
instance,
model.predict_proba if hasattr(model, 'predict_proba') else model.predict
)
exp = explainer.explain_instance(instance, model.predict_proba)
logger.info("Model explained...")
logger.info("LIME explanation generated successfully.")
return exp
except Exception as e:
logger.error(f"Some error occurred in explaining model...{str(e)}")
logger.error(f"Error in get_lime_explanation: {str(e)}")
raise ExplainableAIError(f"Error in get_lime_explanation: {str(e)}")

def plot_lime_explanation(exp):
exp.as_pyplot_figure()
plt.tight_layout()
plt.savefig('lime_explanation.png')
plt.close()
logger.debug("Plotting LIME explanation...")
try:
plt.figure(figsize=(12, 8))
exp.as_pyplot_figure()
plt.tight_layout()

# Save plot to file
plt.savefig('lime_explanation.png')
logger.info("LIME explanation plot saved as 'lime_explanation.png'")

# Convert plot to base64 for display
img = io.BytesIO()
plt.savefig(img, format='png')
img.seek(0)
plot_url = base64.b64encode(img.getvalue()).decode()
plt.close()

return plot_url
except Exception as e:
logger.error(f"Error in plot_lime_explanation: {str(e)}")
raise ExplainableAIError(f"Error in plot_lime_explanation: {str(e)}")

def plot_ice_curve(model, X, feature, num_ice_lines=50):
ice_data = X.copy()
feature_values = np.linspace(X[feature].min(), X[feature].max(), num=100)

plt.figure(figsize=(10, 6))
for _ in range(num_ice_lines):
ice_instance = ice_data.sample(n=1, replace=True)
predictions = []
for value in feature_values:
ice_instance[feature] = value
predictions.append(model.predict_proba(ice_instance)[0][1])
plt.plot(feature_values, predictions, color='blue', alpha=0.1)

plt.xlabel(feature)
plt.ylabel('Predicted Probability')
plt.title(f'ICE Plot for {feature}')
plt.tight_layout()
plt.savefig(f'ice_plot_{feature}.png')
plt.close()
def interpret_model(model, X, feature_names, instance_index=0):
logger.info("Starting model interpretation...")
try:
# SHAP analysis
shap_values = calculate_shap_values(model, X, feature_names)
shap_plot_url = plot_shap_summary(shap_values, X, feature_names)

# LIME analysis
instance = X[instance_index]
lime_exp = get_lime_explanation(model, X, instance, feature_names)
lime_plot_url = plot_lime_explanation(lime_exp)

interpretation_results = {
"shap_values": shap_values,
"shap_plot_url": shap_plot_url,
"lime_explanation": lime_exp,
"lime_plot_url": lime_plot_url
}

logger.info("Model interpretation completed successfully.")
return interpretation_results
except Exception as e:
logger.error(f"Error in interpret_model: {str(e)}")
raise ExplainableAIError(f"Error in interpret_model: {str(e)}")
