Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
Signed-off-by: Gaurav Gupta <[email protected]>
  • Loading branch information
gaugup committed Oct 27, 2023
1 parent b63ab59 commit d4e5958
Showing 1 changed file with 27 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import os

import numpy as np
import pytest

from rai_test_utils.models.lightgbm import create_lightgbm_classifier
Expand All @@ -21,10 +22,15 @@ class TestCounterfactualAdvancedFeatures(object):

@pytest.mark.parametrize('vary_all_features', [True, False])
@pytest.mark.parametrize('feature_importance', [True, False])
@pytest.mark.parametrize('encode_target_as_strings', [True, False])
def test_counterfactual_vary_features(
self, vary_all_features, feature_importance):
self, vary_all_features, feature_importance,
encode_target_as_strings):
X_train, X_test, y_train, y_test, feature_names, _ = \
create_iris_data()
if encode_target_as_strings:
y_train = y_train.astype(str)
y_test = y_test.astype(str)

model = create_lightgbm_classifier(X_train, y_train)
X_train['target'] = y_train
Expand All @@ -50,6 +56,26 @@ def test_counterfactual_vary_features(

cf_obj = rai_insights.counterfactual.get()[0]
assert cf_obj is not None
for index in range(0, len(cf_obj.cf_examples_list)):
if encode_target_as_strings:
assert isinstance(
cf_obj.cf_examples_list[
index].test_instance_df['target'].values[0], str)
else:
assert isinstance(
cf_obj.cf_examples_list[
index].test_instance_df['target'].values[0], np.int32)
assert cf_obj.cf_examples_list[
index].test_instance_df['target'].values[0] in set(y_train)

cf_target_array = cf_obj.cf_examples_list[0].final_cfs_df[
'target'].values
for inner_index in range(0, 10):
if encode_target_as_strings:
assert isinstance(cf_target_array[inner_index], str)
else:
assert isinstance(cf_target_array[inner_index], np.int32)
assert cf_target_array[inner_index] in set(y_train)

@pytest.mark.parametrize('feature_importance', [True, False])
def test_counterfactual_permitted_range(self, feature_importance):
Expand Down

0 comments on commit d4e5958

Please sign in to comment.