diff --git a/exercise.ipynb b/exercise.ipynb index 1f76626..4a0a4a5 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "2cec6e2e", + "id": "1b3bf19e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "c9faff54", + "id": "067e738b", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "8b2714f2", + "id": "b913b71a", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "196887ff", + "id": "98283047", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "dc029ece", + "id": "1afbe1b3", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f76a8a6", + "id": "5bef908b", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "9ec21db3", + "id": "3e27811b", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "06793f28", + "id": "c554ff65", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1bc9ae08", + "id": "0f8465d5", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "e0d3e182", + "id": "9460ef99", "metadata": { "lines_to_next_cell": 0 }, @@ -166,7 +166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0bd2cb1e", + "id": "e35f652a", "metadata": { "lines_to_next_cell": 0 }, @@ -195,7 +195,7 @@ }, { "cell_type": "markdown", - "id": "88a42af2", + "id": "714d0118", "metadata": { "lines_to_next_cell": 0 }, @@ -212,7 +212,7 @@ }, { "cell_type": "markdown", - "id": "72fbabcd", + "id": "2b3080ec", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -222,7 +222,7 @@ }, { "cell_type": "markdown", - "id": "5a41df80", + "id": "42bf258c", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -235,7 +235,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c7615db3", + "id": "9dc58dab", "metadata": { "tags": [] }, @@ -253,7 +253,7 @@ }, { "cell_type": "markdown", - "id": "f8bc4307", + "id": "50185c99", "metadata": { "tags": [] }, @@ -269,7 +269,7 @@ { "cell_type": "code", "execution_count": null, - "id": "36075e2c", + "id": "e960ab8b", "metadata": { "tags": [ "task" @@ -290,7 +290,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8ebbe326", + "id": "eed52045", "metadata": { "tags": [] }, @@ -303,7 +303,7 @@ }, { "cell_type": "markdown", - "id": "2de7d44e", + "id": "be5a007b", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -315,7 +315,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c484acb2", + "id": "5c0780f6", "metadata": { "tags": [] }, @@ -343,7 +343,7 @@ { "cell_type": "code", "execution_count": null, - "id": "898595ea", + "id": "30041689", "metadata": { "tags": [] }, @@ -356,7 +356,7 @@ }, { "cell_type": "markdown", - "id": "5b66372f", + "id": "c0d46960", "metadata": { "lines_to_next_cell": 2 }, @@ -370,7 +370,7 @@ }, { "cell_type": "markdown", - "id": "f26f6f42", + "id": "477cd87c", "metadata": { "lines_to_next_cell": 0 }, @@ -383,7 +383,7 @@ { "cell_type": "code", "execution_count": null, - "id": "69cf5669", + "id": "ed247a59", "metadata": {}, "outputs": [], "source": [ @@ -408,7 +408,7 @@ }, { "cell_type": "markdown", - "id": "9642568b", + "id": "746de08d", "metadata": { "lines_to_next_cell": 0 }, @@ -422,7 +422,7 @@ }, { "cell_type": "markdown", - "id": "2891c43d", + "id": "8c5ea46b", "metadata": {}, "source": [ "\n", @@ -448,7 +448,7 @@ }, { "cell_type": "markdown", - "id": "357966b5", + "id": "eaacace4", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -460,7 +460,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8e4e92d0", + "id": "bbbe47bb", "metadata": { "tags": [ "task" @@ -482,7 +482,7 @@ }, { "cell_type": "markdown", - "id": "f9b5fba2", + "id": "bb810907", "metadata": { "tags": [] }, @@ -496,7 +496,7 @@ { "cell_type": "code", "execution_count": null, - "id": "58262250", + "id": "5b379c73", "metadata": { "tags": [ "task" @@ -520,7 +520,7 @@ }, { "cell_type": "markdown", - "id": "ccedb79c", + "id": "342b9336", "metadata": { "tags": [] }, @@ -536,7 +536,7 @@ }, { "cell_type": "markdown", - "id": "68893afb", + "id": "96d713fd", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -550,7 +550,7 @@ }, { "cell_type": "markdown", - "id": "849fa319", + "id": "f6118ade", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -570,7 +570,7 @@ }, { "cell_type": "markdown", - "id": "876bdd3b", + "id": "12781ee6", "metadata": { "lines_to_next_cell": 0 }, @@ -598,7 +598,7 @@ }, { "cell_type": "markdown", - "id": "5a1c2a34", + "id": "6dd2e900", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -621,7 +621,7 @@ { "cell_type": "code", "execution_count": null, - "id": "85aa4a15", + "id": "e489a340", "metadata": {}, "outputs": [], "source": [ @@ -653,7 +653,7 @@ }, { "cell_type": "markdown", - "id": "2b95f3e8", + "id": "d77aa8d1", "metadata": { "lines_to_next_cell": 0 }, @@ -668,7 +668,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ac1205e5", + "id": "974e17d8", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -689,7 +689,7 @@ }, { "cell_type": "markdown", - "id": "9c976d43", + "id": "50f5d295", "metadata": { "tags": [] }, @@ -704,7 +704,7 @@ }, { "cell_type": "markdown", - "id": "881575de", + "id": "f4dd16b3", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -721,7 +721,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e43f27f1", + "id": "ccf31a3c", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -735,7 +735,7 @@ }, { "cell_type": "markdown", - "id": "d29d8d58", + "id": "49ff9e3a", "metadata": { "lines_to_next_cell": 0 }, @@ -746,7 +746,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f51b15fc", + "id": "b0f0fb98", "metadata": {}, "outputs": [], "source": [ @@ -756,7 +756,7 @@ }, { "cell_type": "markdown", - "id": "73814dae", + "id": "23b21a32", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -774,7 +774,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5fc36682", + "id": "29678f3b", "metadata": { "lines_to_next_cell": 0 }, @@ -786,7 +786,7 @@ }, { "cell_type": "markdown", - "id": "e07edb85", + "id": "6c069dc6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -805,7 +805,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5082b8ab", + "id": "321515da", "metadata": {}, "outputs": [], "source": [ @@ -814,7 +814,7 @@ }, { "cell_type": "markdown", - "id": "43ef8a2d", + "id": "060fd784", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -830,7 +830,7 @@ { "cell_type": "code", "execution_count": null, - "id": "468fe88f", + "id": "8c85a285", "metadata": {}, "outputs": [], "source": [ @@ -839,7 +839,7 @@ }, { "cell_type": "markdown", - "id": "6dd0f856", + "id": "511db0c6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -851,7 +851,7 @@ { "cell_type": "code", "execution_count": null, - "id": "89da32d6", + "id": "6c3dfe27", "metadata": {}, "outputs": [], "source": [ @@ -864,7 +864,7 @@ }, { "cell_type": "markdown", - "id": "2b894eac", + "id": "aca01927", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -878,7 +878,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b25e284e", + "id": "c0bb0943", "metadata": {}, "outputs": [], "source": [ @@ -890,7 +890,7 @@ }, { "cell_type": "markdown", - "id": "72125f93", + "id": "1c17e308", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -910,7 +910,7 @@ { "cell_type": "code", "execution_count": null, - "id": "41cf9baf", + "id": "c378d7f9", "metadata": {}, "outputs": [], "source": [ @@ -934,7 +934,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a5b1fb2e", + "id": "07f4fecf", "metadata": { "lines_to_next_cell": 2 }, @@ -946,7 +946,7 @@ }, { "cell_type": "markdown", - "id": "421a4724", + "id": "452728c7", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -968,7 +968,7 @@ }, { "cell_type": "markdown", - "id": "d375c071", + "id": "9a73a3eb", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -980,7 +980,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f782b7f0", + "id": "01ddb9f7", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1091,7 +1091,7 @@ }, { "cell_type": "markdown", - "id": "bf2700e1", + "id": "2b1901cd", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1103,7 +1103,7 @@ { "cell_type": "code", "execution_count": null, - "id": "93af3047", + "id": "32eebf00", "metadata": {}, "outputs": [], "source": [ @@ -1119,7 +1119,7 @@ }, { "cell_type": "markdown", - "id": "65452041", + "id": "fd98dc27", "metadata": { "tags": [] }, @@ -1134,7 +1134,7 @@ }, { "cell_type": "markdown", - "id": "db668256", + "id": "c74e6127", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1146,7 +1146,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f76c5e65", + "id": "d38a9dfd", "metadata": {}, "outputs": [], "source": [ @@ -1168,7 +1168,7 @@ }, { "cell_type": "markdown", - "id": "573bb9d9", + "id": "bd940bab", "metadata": { "tags": [] }, @@ -1184,7 +1184,7 @@ }, { "cell_type": "markdown", - "id": "26dfd88b", + "id": "24e48cf3", "metadata": { "tags": [] }, @@ -1194,7 +1194,7 @@ }, { "cell_type": "markdown", - "id": "f48126ab", + "id": "1f56a54f", "metadata": { "tags": [] }, @@ -1211,7 +1211,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4305090b", + "id": "8f572d41", "metadata": { "title": "Loading the test dataset" }, @@ -1231,7 +1231,7 @@ }, { "cell_type": "markdown", - "id": "154db796", + "id": "dc260633", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1243,7 +1243,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ebbcd12d", + "id": "0429a766", "metadata": {}, "outputs": [], "source": [ @@ -1256,7 +1256,7 @@ }, { "cell_type": "markdown", - "id": "61d5d4bc", + "id": "80f4b326", "metadata": { "lines_to_next_cell": 0 }, @@ -1266,7 +1266,7 @@ }, { "cell_type": "markdown", - "id": "7ae07350", + "id": "8443a95a", "metadata": { "lines_to_next_cell": 0 }, @@ -1284,7 +1284,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bca04b2a", + "id": "012af37d", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1320,7 +1320,7 @@ }, { "cell_type": "markdown", - "id": "bebb99c9", + "id": "9a2a6a90", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1332,7 +1332,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c87df0c0", + "id": "ca2d44e7", "metadata": {}, "outputs": [], "source": [ @@ -1345,7 +1345,7 @@ }, { "cell_type": "markdown", - "id": "2fbc38e9", + "id": "63966675", "metadata": { "tags": [] }, @@ -1360,7 +1360,7 @@ }, { "cell_type": "markdown", - "id": "baf8c83a", + "id": "413852c4", "metadata": { "tags": [] }, @@ -1371,7 +1371,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2465c04b", + "id": "65d61fc4", "metadata": {}, "outputs": [], "source": [ @@ -1385,7 +1385,7 @@ }, { "cell_type": "markdown", - "id": "35613805", + "id": "b926cdc6", "metadata": { "tags": [] }, @@ -1400,7 +1400,7 @@ }, { "cell_type": "markdown", - "id": "321fcad6", + "id": "4aaf7349", "metadata": { "lines_to_next_cell": 0 }, @@ -1415,7 +1415,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1c45c814", + "id": "ac005791", "metadata": {}, "outputs": [], "source": [ @@ -1436,7 +1436,7 @@ { "cell_type": "code", "execution_count": null, - "id": "584f13f4", + "id": "8f3faac4", "metadata": { "title": "Another visualization function" }, @@ -1465,7 +1465,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df33dd4a", + "id": "956658ca", "metadata": { "lines_to_next_cell": 0 }, @@ -1473,7 +1473,7 @@ "source": [ "for idx in range(batch_size):\n", " print(\"Source class:\", y[idx].item())\n", - " print(\"Target class:\", 0)\n", + " print(\"Target class:\", target_class)\n", " visualize_color_attribution_and_counterfactual(\n", " attributions[idx].cpu().numpy(), x[idx].cpu().numpy(), x_fake[idx].cpu().numpy()\n", " )" @@ -1481,7 +1481,7 @@ }, { "cell_type": "markdown", - "id": "f2bdfa5d", + "id": "8df09e0c", "metadata": { "lines_to_next_cell": 0 }, @@ -1497,11 +1497,21 @@ }, { "cell_type": "markdown", - "id": "da87745e", + "id": "59557649", "metadata": { "lines_to_next_cell": 0 }, "source": [ + "By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes!\n", + "\n", + "Here are two examples of image-counterfactual-attribution triplets.\n", + "You'll notice that they are *very* similar in every way! But one set is different classes, and one set is the same class!\n", + "\n", + "![same_class](assets/same_class.png)\n", + "![diff_class](assets/diff_class.png)\n", + "\n", + "We are missing a crucial step of the explanation pipeline: a quantification of how the class changes over the interpolation. \n", + "\n", "In the lecture, we used the attribution to act as a mask, to gradually go from the original image to the counterfactual image.\n", "This allowed us to classify all of the intermediate images, and learn how the class changed over the interpolation.\n", "Here we have a much simpler task so we have some advantages:\n", @@ -1512,7 +1522,7 @@ }, { "cell_type": "markdown", - "id": "88f10203", + "id": "0e01c5da", "metadata": { "lines_to_next_cell": 0 }, @@ -1526,7 +1536,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2e79affe", + "id": "a9b88599", "metadata": { "lines_to_next_cell": 0 }, @@ -1546,7 +1556,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70f9846e", + "id": "f402e79a", "metadata": { "lines_to_next_cell": 0 }, @@ -1582,7 +1592,7 @@ }, { "cell_type": "markdown", - "id": "386a9f4c", + "id": "fcefb0bb", "metadata": { "lines_to_next_cell": 0 }, @@ -1596,7 +1606,7 @@ }, { "cell_type": "markdown", - "id": "ae25a656", + "id": "01f9a614", "metadata": { "lines_to_next_cell": 0 }, @@ -1612,7 +1622,7 @@ }, { "cell_type": "markdown", - "id": "53c7b041", + "id": "3bb20dcd", "metadata": { "lines_to_next_cell": 0 }, @@ -1628,30 +1638,19 @@ }, { "cell_type": "markdown", - "id": "c75ed748", + "id": "d546c694", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# Part 5: Exploring the Style Space, finding the answer\n", - "By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes!\n", - "\n", - "Here is an example of two images that are very similar in color, but are of different classes.\n", - "![same_color_diff_class](assets/same_color_diff_class.png)\n", - "While both of the images are yellow, the attribution tells us (if you squint!) that one of the yellows has slightly more blue in it!\n", - "\n", - "Conversely, here is an example of two images with very different colors, but that are of the same class:\n", - "![same_class_diff_color](assets/same_class_diff_color.png)\n", - "Here the attribution is empty! Using the discriminative attribution we can see that the significant color change doesn't matter at all!\n", - "\n", - "\n", "So color is important... but not always? What's going on!?\n", "There is a final piece of information that we can use to solve the puzzle: the style space." ] }, { "cell_type": "markdown", - "id": "45dcb17b", + "id": "bc3ed726", "metadata": {}, "source": [ "

Task 5.1: Explore the style space

\n", @@ -1663,7 +1662,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11309f0a", + "id": "718ce53b", "metadata": {}, "outputs": [], "source": [ @@ -1698,7 +1697,7 @@ }, { "cell_type": "markdown", - "id": "e7e0f8d4", + "id": "ca7a48a9", "metadata": { "lines_to_next_cell": 0 }, @@ -1714,7 +1713,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cc183548", + "id": "e814f376", "metadata": { "lines_to_next_cell": 0 }, @@ -1741,7 +1740,7 @@ }, { "cell_type": "markdown", - "id": "9f34f021", + "id": "bbdc2fa9", "metadata": { "lines_to_next_cell": 0 }, @@ -1755,7 +1754,7 @@ }, { "cell_type": "markdown", - "id": "5d29071b", + "id": "3fe20ca4", "metadata": { "lines_to_next_cell": 0 }, @@ -1772,7 +1771,7 @@ { "cell_type": "code", "execution_count": null, - "id": "366dd45c", + "id": "ff536dd4", "metadata": {}, "outputs": [], "source": [ @@ -1794,7 +1793,7 @@ }, { "cell_type": "markdown", - "id": "b657a0f5", + "id": "d19016ff", "metadata": {}, "source": [ "

Questions

\n", @@ -1806,7 +1805,7 @@ }, { "cell_type": "markdown", - "id": "35b0a886", + "id": "24c11487", "metadata": {}, "source": [ "

Checkpoint 5

\n", @@ -1824,7 +1823,7 @@ }, { "cell_type": "markdown", - "id": "c7f12c18", + "id": "36582123", "metadata": {}, "source": [ "# Bonus!\n", diff --git a/solution.ipynb b/solution.ipynb index b8d9964..e1f331f 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "2cec6e2e", + "id": "1b3bf19e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "c9faff54", + "id": "067e738b", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "8b2714f2", + "id": "b913b71a", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "196887ff", + "id": "98283047", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "dc029ece", + "id": "1afbe1b3", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f76a8a6", + "id": "5bef908b", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "9ec21db3", + "id": "3e27811b", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "06793f28", + "id": "c554ff65", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "becee066", + "id": "885618d6", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "e0d3e182", + "id": "9460ef99", "metadata": { "lines_to_next_cell": 0 }, @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0bd2cb1e", + "id": "e35f652a", "metadata": { "lines_to_next_cell": 0 }, @@ -194,7 +194,7 @@ }, { "cell_type": "markdown", - "id": "88a42af2", + "id": "714d0118", "metadata": { "lines_to_next_cell": 0 }, @@ -211,7 +211,7 @@ }, { "cell_type": "markdown", - "id": "72fbabcd", + "id": "2b3080ec", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -221,7 +221,7 @@ }, { "cell_type": "markdown", - "id": "5a41df80", + "id": "42bf258c", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -234,7 +234,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c7615db3", + "id": "9dc58dab", "metadata": { "tags": [] }, @@ -252,7 +252,7 @@ }, { "cell_type": "markdown", - "id": "f8bc4307", + "id": "50185c99", "metadata": { "tags": [] }, @@ -268,7 +268,7 @@ { "cell_type": "code", "execution_count": null, - "id": "300ad40e", + "id": "a8bce89a", "metadata": { "tags": [ "solution" @@ -292,7 +292,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8ebbe326", + "id": "eed52045", "metadata": { "tags": [] }, @@ -305,7 +305,7 @@ }, { "cell_type": "markdown", - "id": "2de7d44e", + "id": "be5a007b", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -317,7 +317,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c484acb2", + "id": "5c0780f6", "metadata": { "tags": [] }, @@ -345,7 +345,7 @@ { "cell_type": "code", "execution_count": null, - "id": "898595ea", + "id": "30041689", "metadata": { "tags": [] }, @@ -358,7 +358,7 @@ }, { "cell_type": "markdown", - "id": "5b66372f", + "id": "c0d46960", "metadata": { "lines_to_next_cell": 2 }, @@ -372,7 +372,7 @@ }, { "cell_type": "markdown", - "id": "f26f6f42", + "id": "477cd87c", "metadata": { "lines_to_next_cell": 0 }, @@ -385,7 +385,7 @@ { "cell_type": "code", "execution_count": null, - "id": "69cf5669", + "id": "ed247a59", "metadata": {}, "outputs": [], "source": [ @@ -410,7 +410,7 @@ }, { "cell_type": "markdown", - "id": "9642568b", + "id": "746de08d", "metadata": { "lines_to_next_cell": 0 }, @@ -424,7 +424,7 @@ }, { "cell_type": "markdown", - "id": "2891c43d", + "id": "8c5ea46b", "metadata": {}, "source": [ "\n", @@ -450,7 +450,7 @@ }, { "cell_type": "markdown", - "id": "357966b5", + "id": "eaacace4", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -462,7 +462,7 @@ { "cell_type": "code", "execution_count": null, - "id": "99c5b515", + "id": "35062d2c", "metadata": { "tags": [ "solution" @@ -488,7 +488,7 @@ }, { "cell_type": "markdown", - "id": "f9b5fba2", + "id": "bb810907", "metadata": { "tags": [] }, @@ -502,7 +502,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2ba867e8", + "id": "367cb0b3", "metadata": { "tags": [ "solution" @@ -532,7 +532,7 @@ }, { "cell_type": "markdown", - "id": "ccedb79c", + "id": "342b9336", "metadata": { "tags": [] }, @@ -548,7 +548,7 @@ }, { "cell_type": "markdown", - "id": "68893afb", + "id": "96d713fd", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -562,7 +562,7 @@ }, { "cell_type": "markdown", - "id": "849fa319", + "id": "f6118ade", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -582,7 +582,7 @@ }, { "cell_type": "markdown", - "id": "876bdd3b", + "id": "12781ee6", "metadata": { "lines_to_next_cell": 0 }, @@ -610,7 +610,7 @@ }, { "cell_type": "markdown", - "id": "5a1c2a34", + "id": "6dd2e900", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -633,7 +633,7 @@ { "cell_type": "code", "execution_count": null, - "id": "85aa4a15", + "id": "e489a340", "metadata": {}, "outputs": [], "source": [ @@ -665,7 +665,7 @@ }, { "cell_type": "markdown", - "id": "2b95f3e8", + "id": "d77aa8d1", "metadata": { "lines_to_next_cell": 0 }, @@ -680,7 +680,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ff8cbb81", + "id": "8ec9d68a", "metadata": { "tags": [ "solution" @@ -697,7 +697,7 @@ }, { "cell_type": "markdown", - "id": "9c976d43", + "id": "50f5d295", "metadata": { "tags": [] }, @@ -712,7 +712,7 @@ }, { "cell_type": "markdown", - "id": "881575de", + "id": "f4dd16b3", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -729,7 +729,7 @@ { "cell_type": "code", "execution_count": null, - "id": "de64cf3c", + "id": "7152c016", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -743,7 +743,7 @@ }, { "cell_type": "markdown", - "id": "d29d8d58", + "id": "49ff9e3a", "metadata": { "lines_to_next_cell": 0 }, @@ -754,7 +754,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f51b15fc", + "id": "b0f0fb98", "metadata": {}, "outputs": [], "source": [ @@ -764,7 +764,7 @@ }, { "cell_type": "markdown", - "id": "73814dae", + "id": "23b21a32", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -782,7 +782,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5fc36682", + "id": "29678f3b", "metadata": { "lines_to_next_cell": 0 }, @@ -794,7 +794,7 @@ }, { "cell_type": "markdown", - "id": "e07edb85", + "id": "6c069dc6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -813,7 +813,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5082b8ab", + "id": "321515da", "metadata": {}, "outputs": [], "source": [ @@ -822,7 +822,7 @@ }, { "cell_type": "markdown", - "id": "43ef8a2d", + "id": "060fd784", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -838,7 +838,7 @@ { "cell_type": "code", "execution_count": null, - "id": "468fe88f", + "id": "8c85a285", "metadata": {}, "outputs": [], "source": [ @@ -847,7 +847,7 @@ }, { "cell_type": "markdown", - "id": "6dd0f856", + "id": "511db0c6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -859,7 +859,7 @@ { "cell_type": "code", "execution_count": null, - "id": "89da32d6", + "id": "6c3dfe27", "metadata": {}, "outputs": [], "source": [ @@ -872,7 +872,7 @@ }, { "cell_type": "markdown", - "id": "2b894eac", + "id": "aca01927", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -886,7 +886,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b25e284e", + "id": "c0bb0943", "metadata": {}, "outputs": [], "source": [ @@ -898,7 +898,7 @@ }, { "cell_type": "markdown", - "id": "72125f93", + "id": "1c17e308", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -918,7 +918,7 @@ { "cell_type": "code", "execution_count": null, - "id": "41cf9baf", + "id": "c378d7f9", "metadata": {}, "outputs": [], "source": [ @@ -942,7 +942,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a5b1fb2e", + "id": "07f4fecf", "metadata": { "lines_to_next_cell": 2 }, @@ -954,7 +954,7 @@ }, { "cell_type": "markdown", - "id": "421a4724", + "id": "452728c7", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -976,7 +976,7 @@ }, { "cell_type": "markdown", - "id": "d375c071", + "id": "9a73a3eb", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -988,7 +988,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c79ac460", + "id": "a32e6ffb", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1058,7 +1058,7 @@ }, { "cell_type": "markdown", - "id": "bf2700e1", + "id": "2b1901cd", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1070,7 +1070,7 @@ { "cell_type": "code", "execution_count": null, - "id": "93af3047", + "id": "32eebf00", "metadata": {}, "outputs": [], "source": [ @@ -1086,7 +1086,7 @@ }, { "cell_type": "markdown", - "id": "65452041", + "id": "fd98dc27", "metadata": { "tags": [] }, @@ -1101,7 +1101,7 @@ }, { "cell_type": "markdown", - "id": "db668256", + "id": "c74e6127", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1113,7 +1113,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f76c5e65", + "id": "d38a9dfd", "metadata": {}, "outputs": [], "source": [ @@ -1135,7 +1135,7 @@ }, { "cell_type": "markdown", - "id": "573bb9d9", + "id": "bd940bab", "metadata": { "tags": [] }, @@ -1151,7 +1151,7 @@ }, { "cell_type": "markdown", - "id": "26dfd88b", + "id": "24e48cf3", "metadata": { "tags": [] }, @@ -1161,7 +1161,7 @@ }, { "cell_type": "markdown", - "id": "f48126ab", + "id": "1f56a54f", "metadata": { "tags": [] }, @@ -1178,7 +1178,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4305090b", + "id": "8f572d41", "metadata": { "title": "Loading the test dataset" }, @@ -1198,7 +1198,7 @@ }, { "cell_type": "markdown", - "id": "154db796", + "id": "dc260633", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1210,7 +1210,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ebbcd12d", + "id": "0429a766", "metadata": {}, "outputs": [], "source": [ @@ -1223,7 +1223,7 @@ }, { "cell_type": "markdown", - "id": "61d5d4bc", + "id": "80f4b326", "metadata": { "lines_to_next_cell": 0 }, @@ -1233,7 +1233,7 @@ }, { "cell_type": "markdown", - "id": "7ae07350", + "id": "8443a95a", "metadata": { "lines_to_next_cell": 0 }, @@ -1251,7 +1251,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ede3165e", + "id": "e448e313", "metadata": { "tags": [ "solution" @@ -1288,7 +1288,7 @@ }, { "cell_type": "markdown", - "id": "bebb99c9", + "id": "9a2a6a90", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1300,7 +1300,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c87df0c0", + "id": "ca2d44e7", "metadata": {}, "outputs": [], "source": [ @@ -1313,7 +1313,7 @@ }, { "cell_type": "markdown", - "id": "2fbc38e9", + "id": "63966675", "metadata": { "tags": [] }, @@ -1328,7 +1328,7 @@ }, { "cell_type": "markdown", - "id": "baf8c83a", + "id": "413852c4", "metadata": { "tags": [] }, @@ -1339,7 +1339,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2465c04b", + "id": "65d61fc4", "metadata": {}, "outputs": [], "source": [ @@ -1353,7 +1353,7 @@ }, { "cell_type": "markdown", - "id": "35613805", + "id": "b926cdc6", "metadata": { "tags": [] }, @@ -1368,7 +1368,7 @@ }, { "cell_type": "markdown", - "id": "321fcad6", + "id": "4aaf7349", "metadata": { "lines_to_next_cell": 0 }, @@ -1383,7 +1383,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1c45c814", + "id": "ac005791", "metadata": {}, "outputs": [], "source": [ @@ -1404,7 +1404,7 @@ { "cell_type": "code", "execution_count": null, - "id": "584f13f4", + "id": "8f3faac4", "metadata": { "title": "Another visualization function" }, @@ -1433,7 +1433,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df33dd4a", + "id": "956658ca", "metadata": { "lines_to_next_cell": 0 }, @@ -1441,7 +1441,7 @@ "source": [ "for idx in range(batch_size):\n", " print(\"Source class:\", y[idx].item())\n", - " print(\"Target class:\", 0)\n", + " print(\"Target class:\", target_class)\n", " visualize_color_attribution_and_counterfactual(\n", " attributions[idx].cpu().numpy(), x[idx].cpu().numpy(), x_fake[idx].cpu().numpy()\n", " )" @@ -1449,7 +1449,7 @@ }, { "cell_type": "markdown", - "id": "f2bdfa5d", + "id": "8df09e0c", "metadata": { "lines_to_next_cell": 0 }, @@ -1465,11 +1465,21 @@ }, { "cell_type": "markdown", - "id": "da87745e", + "id": "59557649", "metadata": { "lines_to_next_cell": 0 }, "source": [ + "By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes!\n", + "\n", + "Here are two examples of image-counterfactual-attribution triplets.\n", + "You'll notice that they are *very* similar in every way! But one set is different classes, and one set is the same class!\n", + "\n", + "![same_class](assets/same_class.png)\n", + "![diff_class](assets/diff_class.png)\n", + "\n", + "We are missing a crucial step of the explanation pipeline: a quantification of how the class changes over the interpolation. \n", + "\n", "In the lecture, we used the attribution to act as a mask, to gradually go from the original image to the counterfactual image.\n", "This allowed us to classify all of the intermediate images, and learn how the class changed over the interpolation.\n", "Here we have a much simpler task so we have some advantages:\n", @@ -1480,7 +1490,7 @@ }, { "cell_type": "markdown", - "id": "88f10203", + "id": "0e01c5da", "metadata": { "lines_to_next_cell": 0 }, @@ -1494,7 +1504,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2e79affe", + "id": "a9b88599", "metadata": { "lines_to_next_cell": 0 }, @@ -1514,7 +1524,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70f9846e", + "id": "f402e79a", "metadata": { "lines_to_next_cell": 0 }, @@ -1550,7 +1560,7 @@ }, { "cell_type": "markdown", - "id": "386a9f4c", + "id": "fcefb0bb", "metadata": { "lines_to_next_cell": 0 }, @@ -1564,7 +1574,7 @@ }, { "cell_type": "markdown", - "id": "ae25a656", + "id": "01f9a614", "metadata": { "lines_to_next_cell": 0 }, @@ -1580,7 +1590,7 @@ }, { "cell_type": "markdown", - "id": "53c7b041", + "id": "3bb20dcd", "metadata": { "lines_to_next_cell": 0 }, @@ -1596,30 +1606,19 @@ }, { "cell_type": "markdown", - "id": "c75ed748", + "id": "d546c694", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# Part 5: Exploring the Style Space, finding the answer\n", - "By now you will have hopefully noticed that it isn't the exact color of the image that determines its class, but that two images with a very similar color can be of different classes!\n", - "\n", - "Here is an example of two images that are very similar in color, but are of different classes.\n", - "![same_color_diff_class](assets/same_color_diff_class.png)\n", - "While both of the images are yellow, the attribution tells us (if you squint!) that one of the yellows has slightly more blue in it!\n", - "\n", - "Conversely, here is an example of two images with very different colors, but that are of the same class:\n", - "![same_class_diff_color](assets/same_class_diff_color.png)\n", - "Here the attribution is empty! Using the discriminative attribution we can see that the significant color change doesn't matter at all!\n", - "\n", - "\n", "So color is important... but not always? What's going on!?\n", "There is a final piece of information that we can use to solve the puzzle: the style space." ] }, { "cell_type": "markdown", - "id": "45dcb17b", + "id": "bc3ed726", "metadata": {}, "source": [ "

Task 5.1: Explore the style space

\n", @@ -1631,7 +1630,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11309f0a", + "id": "718ce53b", "metadata": {}, "outputs": [], "source": [ @@ -1666,7 +1665,7 @@ }, { "cell_type": "markdown", - "id": "e7e0f8d4", + "id": "ca7a48a9", "metadata": { "lines_to_next_cell": 0 }, @@ -1682,7 +1681,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cc183548", + "id": "e814f376", "metadata": { "lines_to_next_cell": 0 }, @@ -1709,7 +1708,7 @@ }, { "cell_type": "markdown", - "id": "9f34f021", + "id": "bbdc2fa9", "metadata": { "lines_to_next_cell": 0 }, @@ -1723,7 +1722,7 @@ }, { "cell_type": "markdown", - "id": "5d29071b", + "id": "3fe20ca4", "metadata": { "lines_to_next_cell": 0 }, @@ -1740,7 +1739,7 @@ { "cell_type": "code", "execution_count": null, - "id": "366dd45c", + "id": "ff536dd4", "metadata": {}, "outputs": [], "source": [ @@ -1762,7 +1761,7 @@ }, { "cell_type": "markdown", - "id": "b657a0f5", + "id": "d19016ff", "metadata": {}, "source": [ "

Questions

\n", @@ -1774,7 +1773,7 @@ }, { "cell_type": "markdown", - "id": "35b0a886", + "id": "24c11487", "metadata": {}, "source": [ "

Checkpoint 5

\n", @@ -1792,7 +1791,7 @@ }, { "cell_type": "markdown", - "id": "c7f12c18", + "id": "36582123", "metadata": {}, "source": [ "# Bonus!\n", @@ -1807,7 +1806,7 @@ }, { "cell_type": "markdown", - "id": "93aa15d5", + "id": "44f3ac5f", "metadata": { "tags": [ "solution" @@ -1821,7 +1820,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ac995812", + "id": "3144a680", "metadata": { "tags": [ "solution"