diff --git a/tests/test_simba.py b/tests/test_simba.py index 1fcab61..3b3e447 100644 --- a/tests/test_simba.py +++ b/tests/test_simba.py @@ -13,7 +13,7 @@ class TestSimBA(unittest.TestCase): def test_simba(self): # Load Image [0.0, 1.0] - x = np.asarray(Image.open("tests/dog.jpg").resize((32, 32))) / 255.0 + x = np.asarray(Image.open("tests/cat.jpg").resize((32, 32))) / 255.0 # Initialize API Model model = VGG16Cifar10("https://api.wuhanstudio.uk" + "/vgg16_cifar10") @@ -21,7 +21,7 @@ def test_simba(self): # Get Preditction y_pred = model.predict(np.array([x]))[0] - assert (np.argmax(y_pred) == 5) + assert (np.argmax(y_pred) == 3) # SimBA Attack simba = SimBA(model)