Skip to content

Commit

Permalink
added SVM to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ombhojane committed Sep 26, 2024
1 parent 709d835 commit 16c7906
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
8 changes: 6 additions & 2 deletions explainableai/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,18 @@
import matplotlib.pyplot as plt
import numpy as np

def compare_models(X_train, y_train, X_test, y_test):
models = {
def get_default_models():
return {
'Logistic Regression': LogisticRegression(max_iter=1000),
'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42),
'SVM': SVC(probability=True, random_state=42),
'XGBoost': XGBClassifier(n_estimators=100, random_state=42),
'Neural Network': MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=1000, random_state=42)
}

def compare_models(X_train, y_train, X_test, y_test, models=None):
if models is None:
models = get_default_models()

results = {}
for name, model in models.items():
Expand Down
Binary file not shown.
6 changes: 4 additions & 2 deletions tests/test_xai_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from explainableai import XAIWrapper
import os
Expand All @@ -23,7 +24,8 @@ def sample_models():
'Random Forest': RandomForestClassifier(n_estimators=10, random_state=42),
'Logistic Regression': LogisticRegression(max_iter=1000),
'XGBoost': XGBClassifier(n_estimators=10, random_state=42),
'Neural Network': MLPClassifier(hidden_layer_sizes=(10,), max_iter=1000, random_state=42)
'Neural Network': MLPClassifier(hidden_layer_sizes=(10,), max_iter=1000, random_state=42),
'SVM': SVC(probability=True, random_state=42)
}

def test_xai_wrapper_initialization(sample_data, sample_models):
Expand All @@ -44,7 +46,7 @@ def test_xai_wrapper_fit(sample_data, sample_models):
assert hasattr(xai.model, 'predict')
assert hasattr(xai.model, 'predict_proba')

@pytest.mark.parametrize("model_name", ['Random Forest', 'Logistic Regression', 'XGBoost', 'Neural Network'])
@pytest.mark.parametrize("model_name", ['Random Forest', 'Logistic Regression', 'XGBoost', 'Neural Network', 'SVM'])
def test_xai_wrapper_analyze_with_different_models(sample_data, sample_models, model_name):
X, y = sample_data
models = {model_name: sample_models[model_name]}
Expand Down

0 comments on commit 16c7906

Please sign in to comment.