diff --git a/exercise.ipynb b/exercise.ipynb index dd8f750..557e3a5 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "46b2ed85", + "id": "d5acca24", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "3ca92ed6", + "id": "cb662602", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "5452b543", + "id": "7af25d33", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f6a6be8", + "id": "f2d6db31", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "d521dde9", + "id": "3ef3bfa9", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6d1b1c06", + "id": "02be6af2", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "90516193", + "id": "6a76f50d", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "8ec094ee", + "id": "fae7773e", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9709f106", + "id": "01199870", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "e8d2799d", + "id": "137cd457", "metadata": { "lines_to_next_cell": 0 }, @@ -166,7 +166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e0a495ea", + "id": "3c09237d", "metadata": {}, "outputs": [], "source": [ @@ -193,7 +193,7 @@ }, { "cell_type": "markdown", - "id": "3e5df200", + "id": "92af9975", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -203,7 +203,7 @@ }, { "cell_type": "markdown", - "id": "f3b1fb74", + "id": "c6d16b90", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -216,7 +216,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df74345f", + "id": "29c8613f", "metadata": { "tags": [] }, @@ -234,7 +234,7 @@ }, { "cell_type": "markdown", - "id": "43f8c1b4", + "id": "ba7abfbe", "metadata": { "tags": [] }, @@ -250,7 +250,7 @@ { "cell_type": "code", "execution_count": null, - "id": "327a7562", + "id": "a842f6b9", "metadata": { "tags": [ "task" @@ -271,7 +271,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9907f91b", + "id": "d4148689", "metadata": { "tags": [] }, @@ -284,7 +284,7 @@ }, { "cell_type": "markdown", - "id": "61bab00d", + "id": "276c658d", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -296,7 +296,7 @@ { "cell_type": "code", "execution_count": null, - "id": "93d8ac21", + "id": "30d569fb", "metadata": { "tags": [] }, @@ -324,7 +324,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a2476939", + "id": "570a7ccd", "metadata": { "tags": [] }, @@ -337,7 +337,7 @@ }, { "cell_type": "markdown", - "id": "b7e03da8", + "id": "4f13b2f4", "metadata": { "lines_to_next_cell": 2 }, @@ -351,7 +351,7 @@ }, { "cell_type": "markdown", - "id": "627449ca", + "id": "8c8790d3", "metadata": { "lines_to_next_cell": 0 }, @@ -364,7 +364,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8c1cee24", + "id": "ab2e8a0d", "metadata": {}, "outputs": [], "source": [ @@ -376,7 +376,7 @@ " ax1.imshow(original_image)\n", " ax1.set_title(\"Image\")\n", " ax1.axis(\"off\")\n", - " ax2.imshow(np.abs(attribution))\n", + " ax2.imshow(np.abs(attribution) / 
np.max(np.abs(attribution)))\n", " ax2.set_title(\"Attribution\")\n", " ax2.axis(\"off\")\n", " plt.show()\n", @@ -389,7 +389,7 @@ }, { "cell_type": "markdown", - "id": "e887a55c", + "id": "b031ab39", "metadata": { "lines_to_next_cell": 0 }, @@ -403,7 +403,7 @@ }, { "cell_type": "markdown", - "id": "3f5d9121", + "id": "e0a15147", "metadata": {}, "source": [ "\n", @@ -429,7 +429,7 @@ }, { "cell_type": "markdown", - "id": "b0ea4cb1", + "id": "10e59bd9", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -441,7 +441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d0ac809d", + "id": "939622d5", "metadata": { "tags": [ "task" @@ -462,7 +462,7 @@ }, { "cell_type": "markdown", - "id": "9163c4d8", + "id": "82e2a1b3", "metadata": { "tags": [] }, @@ -476,7 +476,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12152fe9", + "id": "8fae20a2", "metadata": { "tags": [ "task" @@ -499,7 +499,7 @@ }, { "cell_type": "markdown", - "id": "cb79d651", + "id": "446d0aab", "metadata": { "tags": [] }, @@ -515,7 +515,7 @@ }, { "cell_type": "markdown", - "id": "d102ec55", + "id": "c0107144", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -529,7 +529,7 @@ }, { "cell_type": "markdown", - "id": "758de8e5", + "id": "9a69110b", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -549,7 +549,7 @@ }, { "cell_type": "markdown", - "id": "9dd5ce8d", + "id": "d6cd057a", "metadata": { "lines_to_next_cell": 0 }, @@ -577,7 +577,7 @@ }, { "cell_type": "markdown", - "id": "4497f894", + "id": "003cd7d6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -600,7 +600,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4eba79a8", + "id": "c99ee37d", "metadata": {}, "outputs": [], "source": [ @@ -632,7 +632,7 @@ }, { "cell_type": "markdown", - "id": "4e369a2c", + "id": "9a1ba5d0", "metadata": { "lines_to_next_cell": 0 }, @@ -647,7 +647,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7dfcc3a5", + "id": "0e255fc6", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -668,7 +668,7 @@ }, { "cell_type": "markdown", - "id": "2f351532", + "id": "d0d5441a", "metadata": { "tags": [] }, @@ -683,7 +683,7 @@ }, { "cell_type": "markdown", - "id": "fef243a6", + "id": "1a5da3b0", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -700,7 +700,7 @@ { "cell_type": "code", "execution_count": null, - "id": "84848da5", + "id": "a17d88eb", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -714,7 +714,7 @@ }, { "cell_type": "markdown", - "id": "331a5aa9", + "id": "6152ae96", "metadata": { "lines_to_next_cell": 0 }, @@ -725,7 +725,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5da1667d", + "id": "47ccc999", "metadata": {}, "outputs": [], "source": [ @@ -735,7 +735,7 @@ }, { "cell_type": "markdown", - "id": "52c90cba", + "id": "235ecf34", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -753,7 +753,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bbbf486c", + "id": "ecd33083", "metadata": { "lines_to_next_cell": 0 }, @@ -765,7 +765,7 @@ }, { "cell_type": "markdown", - "id": "900d4b51", + "id": "46499a30", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -784,7 +784,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ab86e2bc", + "id": "a75d1d8e", "metadata": {}, "outputs": [], "source": [ @@ -793,7 +793,7 @@ }, { "cell_type": "markdown", - "id": "bb7046e9", + "id": "6f1e7957", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -809,7 +809,7 @@ { "cell_type": "code", "execution_count": null, - "id": "049d4a98", + "id": "48c11783", "metadata": {}, "outputs": [], "source": [ @@ -818,7 +818,7 @@ }, { "cell_type": "markdown", - "id": "8c3e7170", + "id": "7a403168", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -830,7 +830,7 @@ { "cell_type": "code", "execution_count": null, - "id": "209395dd", + "id": "197f0b55", "metadata": {}, "outputs": [], "source": [ @@ -843,7 +843,7 @@ }, { "cell_type": "markdown", - "id": "2513fff6", + "id": "14642c7c", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -857,7 +857,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e6413b39", + "id": "8567a4c7", "metadata": {}, "outputs": [], "source": [ @@ -869,7 +869,7 @@ }, { "cell_type": "markdown", - "id": "1022dd64", + "id": "d617a57b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -889,7 +889,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3765f525", + "id": "152d1fc2", "metadata": {}, "outputs": [], "source": [ @@ -913,7 +913,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e6dc12c2", + "id": "1beae4d6", "metadata": {}, "outputs": [], "source": [ @@ -923,7 +923,7 @@ }, { "cell_type": "markdown", - "id": "f572e6ab", + "id": "6fba9909", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -945,7 +945,7 @@ }, { "cell_type": "markdown", - "id": "0725d30d", + "id": "07dd61b7", "metadata": { 
"lines_to_next_cell": 0, "tags": [] @@ -957,7 +957,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6feaf7e3", + "id": "df4f3916", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1068,7 +1068,7 @@ }, { "cell_type": "markdown", - "id": "cf7e2d33", + "id": "0574b61f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1080,7 +1080,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f8f84461", + "id": "94723f87", "metadata": {}, "outputs": [], "source": [ @@ -1096,7 +1096,7 @@ }, { "cell_type": "markdown", - "id": "425b6eb8", + "id": "a734af98", "metadata": { "tags": [] }, @@ -1111,7 +1111,7 @@ }, { "cell_type": "markdown", - "id": "0f81593a", + "id": "34e2b5cd", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1123,7 +1123,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8fb9650f", + "id": "22f64934", "metadata": {}, "outputs": [], "source": [ @@ -1143,19 +1143,9 @@ "plt.show()" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "cf9cfe92", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", - "id": "c31b5230", + "id": "d8dfbe12", "metadata": { "tags": [] }, @@ -1171,7 +1161,7 @@ }, { "cell_type": "markdown", - "id": "afb5d3fe", + "id": "32991250", "metadata": { "tags": [] }, @@ -1181,7 +1171,7 @@ }, { "cell_type": "markdown", - "id": "96e6c29a", + "id": "216383e1", "metadata": { "tags": [] }, @@ -1198,7 +1188,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8fcd2b8a", + "id": "54fd4016", "metadata": { "title": "Loading the test dataset" }, @@ -1218,7 +1208,7 @@ }, { "cell_type": "markdown", - "id": "e58a66c1", + "id": "1dbcc267", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1230,7 +1220,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b6984bec", + "id": "04b20178", "metadata": {}, "outputs": [], "source": [ @@ -1243,7 +1233,7 @@ }, { "cell_type": "markdown", - "id": "09a033e4", + "id": "ea19f383", "metadata": { "lines_to_next_cell": 0 }, @@ -1253,12 +1243,12 @@ }, { "cell_type": "markdown", - "id": "8d6dafc8", + "id": "e10ea481", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

Task 4: Create counterfactuals

\n", + "

Task 4.1: Create counterfactuals

\n", "In the below, we will store the counterfactual images in the `counterfactuals` array.\n", "\n", "
    \n", @@ -1271,7 +1261,7 @@ { "cell_type": "code", "execution_count": null, - "id": "98d8e7f1", + "id": "24f75a7b", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1307,7 +1297,7 @@ }, { "cell_type": "markdown", - "id": "5ed9dc6b", + "id": "44a96f81", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1319,7 +1309,7 @@ { "cell_type": "code", "execution_count": null, - "id": "83220035", + "id": "83523c2f", "metadata": {}, "outputs": [], "source": [ @@ -1332,7 +1322,7 @@ }, { "cell_type": "markdown", - "id": "48b400d0", + "id": "bee1aa55", "metadata": { "tags": [] }, @@ -1347,7 +1337,7 @@ }, { "cell_type": "markdown", - "id": "28a66131", + "id": "dbac0e2e", "metadata": { "tags": [] }, @@ -1358,7 +1348,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ba72cac7", + "id": "7d263995", "metadata": {}, "outputs": [], "source": [ @@ -1372,7 +1362,7 @@ }, { "cell_type": "markdown", - "id": "7919e4d2", + "id": "60081f81", "metadata": { "tags": [] }, @@ -1387,7 +1377,7 @@ }, { "cell_type": "markdown", - "id": "5727914c", + "id": "a037d368", "metadata": { "lines_to_next_cell": 0 }, @@ -1402,15 +1392,16 @@ { "cell_type": "code", "execution_count": null, - "id": "0db3b836", + "id": "308d7f21", "metadata": {}, "outputs": [], "source": [ + "target_class = 0\n", "batch_size = 4\n", "batch = [random_test_mnist[i] for i in range(batch_size)]\n", "x = torch.stack([b[0] for b in batch])\n", "y = torch.tensor([b[1] for b in batch])\n", - "x_fake = torch.tensor(counterfactuals[0, :batch_size])\n", + "x_fake = torch.tensor(counterfactuals[target_class, :batch_size])\n", "x = x.to(device).float()\n", "y = y.to(device)\n", "x_fake = x_fake.to(device).float()\n", @@ -1422,7 +1413,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eda02588", + "id": "12c79b2d", "metadata": { "title": "Another visualization function" }, @@ -1442,7 +1433,7 @@ " ax1.imshow(counterfactual_image)\n", " ax1.set_title(\"Counterfactual\")\n", " ax1.axis(\"off\")\n", - " ax2.imshow(np.abs(attribution))\n", + " ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution)))\n", " ax2.set_title(\"Attribution\")\n", " ax2.axis(\"off\")\n", " plt.show()" @@ -1451,7 +1442,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8d6fc8bf", + "id": "aa24deb7", "metadata": { "lines_to_next_cell": 0 }, @@ -1467,7 +1458,7 @@ }, { "cell_type": "markdown", - "id": "b723e27b", + "id": "1726fcb6", "metadata": { "lines_to_next_cell": 0 }, @@ -1483,7 +1474,122 @@ }, { "cell_type": "markdown", - "id": "66ec380e", + "id": "93aff28e", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "In the lecture, we used the attribution to act as a mask, to gradually go from the original image to the counterfactual image.\n", + "This allowed us to classify all of the intermediate images, and learn how the class changed over the interpolation.\n", + "Here we have a much simpler task so we have some advantages:\n", + "- The counterfactuals are perfect! They already change the bare minimum (trust me).\n", + "- The changes are not objects, but colors.\n", + "As such, we will do a much simpler linear interpolation between the images." + ] + }, + { + "cell_type": "markdown", + "id": "0389b58c", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

    Task 4.2: Interpolation

    \n", + "Let's interpolate between the original image and the counterfactual image.\n", + "We will create 10 images in between the two, and classify them.\n", + "
    " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61b9f9a0", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "num_interpolations = 15\n", + "alpha = np.linspace(0, 1, num_interpolations + 2)[1:-1]\n", + "interpolated_images = [\n", + " alpha[i] * x_fake + (1 - alpha[i]) * x for i in range(num_interpolations)\n", + "]\n", + "interpolated_images = torch.stack(interpolated_images)\n", + "interpolated_classifications = [\n", + " model(interpolated_images[idx].to(device)) for idx in range(num_interpolations)\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e70b512", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# Plot the results\n", + "idx = 0\n", + "fig, axs = plt.subplots(\n", + " batch_size, num_interpolations + 2, figsize=(30, 2 * batch_size)\n", + ")\n", + "for idx in range(batch_size):\n", + " # Plot the original image\n", + " axs[idx, 0].imshow(np.transpose(x[idx].cpu().squeeze().numpy(), (1, 2, 0)))\n", + " axs[idx, 0].axis(\"off\")\n", + " # Use the class as the title\n", + " axs[idx, 0].set_title(f\"Image: y={y[idx].item()}\")\n", + " # Plot the counterfactual image\n", + " axs[idx, -1].imshow(np.transpose(x_fake[idx].cpu().squeeze().numpy(), (1, 2, 0)))\n", + " axs[idx, -1].axis(\"off\")\n", + " # Use the target class as the title\n", + " axs[idx, -1].set_title(f\"CF: y={target_class}\")\n", + " for i, ax in enumerate(axs[idx][1:-1]):\n", + " ax.imshow(\n", + " np.transpose(interpolated_images[i][idx].cpu().squeeze().numpy(), (1, 2, 0))\n", + " )\n", + " ax.axis(\"off\")\n", + " classification = torch.softmax(interpolated_classifications[i][idx], dim=0)\n", + " # Plot the classification as the title in order source classification | target classification\n", + " ax.set_title(\n", + " f\"{classification[y[idx]].item():.2f} | {classification[target_class].item():.2f}\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "e9126f01", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "Take some time to look at the plot we just made.\n", + "On the very left are the images we randomly chose - it's class is shown in the title.\n", + "On the very right are the counterfactual images, all of them made with the same prototype as a style source - the target class is shown in the title.\n", + "In between are the interpolated images - their title shows their classification as \"source classification | target classification\".\n", + "This is a lot to take in, so take your time! Once you're ready, we can move on to the questions." + ] + }, + { + "cell_type": "markdown", + "id": "d50ea18c", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

    Questions

    \n", + "
      \n", + "
    • Do the images change smoothly from one class to another?
    • \n", + "
    • Can you see any patterns in the changes?
    • \n", + "
    • What happens when the original image and the counterfactual image are of the same class?
    • \n", + "
    • Based on this, would you trust this classifier on unseen images more or less than you did before?
    • \n", + "
    " + ] + }, + { + "cell_type": "markdown", + "id": "a93a5437", "metadata": { "lines_to_next_cell": 0 }, @@ -1493,12 +1599,13 @@ "- Created a StarGAN that can change the class of an image\n", "- Evaluated the StarGAN on unseen data\n", "- Used the StarGAN to create counterfactual images\n", - "- Used the counterfactual images to highlight the differences between classes\n" + "- Used the counterfactual images to highlight the differences between classes\n", + "- Interpolated between the images to see how the classifier behaves\n" ] }, { "cell_type": "markdown", - "id": "7592e83d", + "id": "b8f44b99", "metadata": { "lines_to_next_cell": 0 }, @@ -1521,7 +1628,7 @@ }, { "cell_type": "markdown", - "id": "a7400c5c", + "id": "aafe8844", "metadata": {}, "source": [ "

    Task 5.1: Explore the style space

    \n", @@ -1533,7 +1640,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f5493fd4", + "id": "0128d7bd", "metadata": {}, "outputs": [], "source": [ @@ -1568,7 +1675,7 @@ }, { "cell_type": "markdown", - "id": "d57f4253", + "id": "609028bd", "metadata": { "lines_to_next_cell": 0 }, @@ -1584,7 +1691,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4160a32f", + "id": "8b7185c5", "metadata": { "lines_to_next_cell": 0 }, @@ -1611,7 +1718,7 @@ }, { "cell_type": "markdown", - "id": "650c6766", + "id": "772e4231", "metadata": { "lines_to_next_cell": 0 }, @@ -1625,7 +1732,7 @@ }, { "cell_type": "markdown", - "id": "f88d9484", + "id": "7d159c2b", "metadata": { "lines_to_next_cell": 0 }, @@ -1642,7 +1749,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0b57d333", + "id": "13eb454f", "metadata": {}, "outputs": [], "source": [ @@ -1665,7 +1772,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4bb3433e", + "id": "3ddc6a07", "metadata": { "lines_to_next_cell": 0 }, @@ -1674,7 +1781,7 @@ }, { "cell_type": "markdown", - "id": "516e0112", + "id": "303338b1", "metadata": {}, "source": [ "

    Questions

    \n", @@ -1686,7 +1793,7 @@ }, { "cell_type": "markdown", - "id": "63cef16c", + "id": "e18da1e0", "metadata": {}, "source": [ "

    Checkpoint 5

    \n", @@ -1701,6 +1808,27 @@ "If you have any questions, feel free to ask them in the chat!\n", "And check the Solutions exercise for a definite answer to how these classes are defined!" ] + }, + { + "cell_type": "markdown", + "id": "b74f495b", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "# Bonus!\n", + "If you have extra time, you can try to break the StarGAN!\n", + "There are a lot of little things that we did to make sure that it runs correctly - but what if we didn't?\n", + "Some things you might want to try:\n", + "- What happens if you don't use the EMA model?\n", + "- What happens if you change the learning rates?\n", + "- What happens if you add a Sigmoid activation to the output of the style encoder?\n", + "See what else you can think of, and see how finnicky training a GAN can be!\n", + "\n", + "# %% [markdown] tags=[\"solution\"]\n", + "The colors for the classes are sampled from matplotlib colormaps! They are the four seasons: spring, summer, autumn, and winter.\n", + "Check your style space again to see if you can see the patterns now!" + ] } ], "metadata": { diff --git a/solution.ipynb b/solution.ipynb index db548a4..dcdff0d 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "46b2ed85", + "id": "d5acca24", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "3ca92ed6", + "id": "cb662602", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "5452b543", + "id": "7af25d33", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f6a6be8", + "id": "f2d6db31", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "d521dde9", + "id": "3ef3bfa9", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6d1b1c06", + "id": "02be6af2", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "90516193", + "id": "6a76f50d", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "8ec094ee", + "id": "fae7773e", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0801e1e5", + "id": "d0b3ad78", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "e8d2799d", + "id": "137cd457", "metadata": { "lines_to_next_cell": 0 }, @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e0a495ea", + "id": "3c09237d", "metadata": {}, "outputs": [], "source": [ @@ -192,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "3e5df200", + "id": "92af9975", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "f3b1fb74", + "id": "c6d16b90", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df74345f", + "id": "29c8613f", "metadata": { "tags": [] }, @@ -233,7 +233,7 @@ }, { "cell_type": "markdown", - "id": "43f8c1b4", + "id": "ba7abfbe", "metadata": { "tags": [] }, @@ -249,7 +249,7 @@ { "cell_type": "code", "execution_count": null, - "id": "52eb1aa0", + "id": "c4e89e02", "metadata": { "tags": [ "solution" @@ -273,7 +273,7 @@ { "cell_type": "code", 
"execution_count": null, - "id": "9907f91b", + "id": "d4148689", "metadata": { "tags": [] }, @@ -286,7 +286,7 @@ }, { "cell_type": "markdown", - "id": "61bab00d", + "id": "276c658d", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -298,7 +298,7 @@ { "cell_type": "code", "execution_count": null, - "id": "93d8ac21", + "id": "30d569fb", "metadata": { "tags": [] }, @@ -326,7 +326,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a2476939", + "id": "570a7ccd", "metadata": { "tags": [] }, @@ -339,7 +339,7 @@ }, { "cell_type": "markdown", - "id": "b7e03da8", + "id": "4f13b2f4", "metadata": { "lines_to_next_cell": 2 }, @@ -353,7 +353,7 @@ }, { "cell_type": "markdown", - "id": "627449ca", + "id": "8c8790d3", "metadata": { "lines_to_next_cell": 0 }, @@ -366,7 +366,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8c1cee24", + "id": "ab2e8a0d", "metadata": {}, "outputs": [], "source": [ @@ -378,7 +378,7 @@ " ax1.imshow(original_image)\n", " ax1.set_title(\"Image\")\n", " ax1.axis(\"off\")\n", - " ax2.imshow(np.abs(attribution))\n", + " ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution)))\n", " ax2.set_title(\"Attribution\")\n", " ax2.axis(\"off\")\n", " plt.show()\n", @@ -391,7 +391,7 @@ }, { "cell_type": "markdown", - "id": "e887a55c", + "id": "b031ab39", "metadata": { "lines_to_next_cell": 0 }, @@ -405,7 +405,7 @@ }, { "cell_type": "markdown", - "id": "3f5d9121", + "id": "e0a15147", "metadata": {}, "source": [ "\n", @@ -431,7 +431,7 @@ }, { "cell_type": "markdown", - "id": "b0ea4cb1", + "id": "10e59bd9", "metadata": {}, "source": [ "

    Task 2.3: Use random noise as a baseline

    \n", @@ -443,7 +443,7 @@ { "cell_type": "code", "execution_count": null, - "id": "72ceec17", + "id": "57012d02", "metadata": { "tags": [ "solution" @@ -469,7 +469,7 @@ }, { "cell_type": "markdown", - "id": "9163c4d8", + "id": "82e2a1b3", "metadata": { "tags": [] }, @@ -483,7 +483,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8ae711b1", + "id": "97f80e4a", "metadata": { "tags": [ "solution" @@ -511,7 +511,7 @@ }, { "cell_type": "markdown", - "id": "cb79d651", + "id": "446d0aab", "metadata": { "tags": [] }, @@ -527,7 +527,7 @@ }, { "cell_type": "markdown", - "id": "d102ec55", + "id": "c0107144", "metadata": {}, "source": [ "

    BONUS Task: Using different attributions.

    \n", @@ -541,7 +541,7 @@ }, { "cell_type": "markdown", - "id": "758de8e5", + "id": "9a69110b", "metadata": {}, "source": [ "

    Checkpoint 2

    \n", @@ -561,7 +561,7 @@ }, { "cell_type": "markdown", - "id": "9dd5ce8d", + "id": "d6cd057a", "metadata": { "lines_to_next_cell": 0 }, @@ -589,7 +589,7 @@ }, { "cell_type": "markdown", - "id": "4497f894", + "id": "003cd7d6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -612,7 +612,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4eba79a8", + "id": "c99ee37d", "metadata": {}, "outputs": [], "source": [ @@ -644,7 +644,7 @@ }, { "cell_type": "markdown", - "id": "4e369a2c", + "id": "9a1ba5d0", "metadata": { "lines_to_next_cell": 0 }, @@ -659,7 +659,7 @@ { "cell_type": "code", "execution_count": null, - "id": "69fb25c6", + "id": "23a3a52c", "metadata": { "tags": [ "solution" @@ -676,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "2f351532", + "id": "d0d5441a", "metadata": { "tags": [] }, @@ -691,7 +691,7 @@ }, { "cell_type": "markdown", - "id": "fef243a6", + "id": "1a5da3b0", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -708,7 +708,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1923e528", + "id": "7488ac02", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -722,7 +722,7 @@ }, { "cell_type": "markdown", - "id": "331a5aa9", + "id": "6152ae96", "metadata": { "lines_to_next_cell": 0 }, @@ -733,7 +733,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5da1667d", + "id": "47ccc999", "metadata": {}, "outputs": [], "source": [ @@ -743,7 +743,7 @@ }, { "cell_type": "markdown", - "id": "52c90cba", + "id": "235ecf34", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -761,7 +761,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bbbf486c", + "id": "ecd33083", "metadata": { "lines_to_next_cell": 0 }, @@ -773,7 +773,7 @@ }, { "cell_type": "markdown", - "id": "900d4b51", + "id": "46499a30", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -792,7 +792,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ab86e2bc", + "id": "a75d1d8e", "metadata": {}, "outputs": [], "source": [ @@ -801,7 +801,7 @@ }, { "cell_type": "markdown", - "id": "bb7046e9", + "id": "6f1e7957", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -817,7 +817,7 @@ { "cell_type": "code", "execution_count": null, - "id": "049d4a98", + "id": "48c11783", "metadata": {}, "outputs": [], "source": [ @@ -826,7 +826,7 @@ }, { "cell_type": "markdown", - "id": "8c3e7170", + "id": "7a403168", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -838,7 +838,7 @@ { "cell_type": "code", "execution_count": null, - "id": "209395dd", + "id": "197f0b55", "metadata": {}, "outputs": [], "source": [ @@ -851,7 +851,7 @@ }, { "cell_type": "markdown", - "id": "2513fff6", + "id": "14642c7c", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -865,7 +865,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e6413b39", + "id": "8567a4c7", "metadata": {}, "outputs": [], "source": [ @@ -877,7 +877,7 @@ }, { "cell_type": "markdown", - "id": "1022dd64", + "id": "d617a57b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -897,7 +897,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3765f525", + "id": "152d1fc2", "metadata": {}, "outputs": [], "source": [ @@ -921,7 +921,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e6dc12c2", + "id": "1beae4d6", "metadata": {}, "outputs": [], "source": [ @@ -931,7 +931,7 @@ }, { "cell_type": "markdown", - "id": "f572e6ab", + "id": "6fba9909", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -953,7 +953,7 @@ }, { "cell_type": "markdown", - "id": "0725d30d", + "id": "07dd61b7", "metadata": { 
"lines_to_next_cell": 0, "tags": [] @@ -965,7 +965,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a50de214", + "id": "5e54a534", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1035,7 +1035,7 @@ }, { "cell_type": "markdown", - "id": "cf7e2d33", + "id": "0574b61f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1047,7 +1047,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f8f84461", + "id": "94723f87", "metadata": {}, "outputs": [], "source": [ @@ -1063,7 +1063,7 @@ }, { "cell_type": "markdown", - "id": "425b6eb8", + "id": "a734af98", "metadata": { "tags": [] }, @@ -1078,7 +1078,7 @@ }, { "cell_type": "markdown", - "id": "0f81593a", + "id": "34e2b5cd", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1090,7 +1090,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8fb9650f", + "id": "22f64934", "metadata": {}, "outputs": [], "source": [ @@ -1110,19 +1110,9 @@ "plt.show()" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "cf9cfe92", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", - "id": "c31b5230", + "id": "d8dfbe12", "metadata": { "tags": [] }, @@ -1138,7 +1128,7 @@ }, { "cell_type": "markdown", - "id": "afb5d3fe", + "id": "32991250", "metadata": { "tags": [] }, @@ -1148,7 +1138,7 @@ }, { "cell_type": "markdown", - "id": "96e6c29a", + "id": "216383e1", "metadata": { "tags": [] }, @@ -1165,7 +1155,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8fcd2b8a", + "id": "54fd4016", "metadata": { "title": "Loading the test dataset" }, @@ -1185,7 +1175,7 @@ }, { "cell_type": "markdown", - "id": "e58a66c1", + "id": "1dbcc267", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1197,7 +1187,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b6984bec", + "id": "04b20178", "metadata": {}, "outputs": [], "source": [ @@ -1210,7 +1200,7 @@ }, { "cell_type": "markdown", - "id": "09a033e4", + "id": "ea19f383", "metadata": { "lines_to_next_cell": 0 }, @@ -1220,12 +1210,12 @@ }, { "cell_type": "markdown", - "id": "8d6dafc8", + "id": "e10ea481", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

    Task 4: Create counterfactuals

    \n", + "

    Task 4.1: Create counterfactuals

    \n", "In the below, we will store the counterfactual images in the `counterfactuals` array.\n", "\n", "
      \n", @@ -1238,7 +1228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "51f8114a", + "id": "780e6fdd", "metadata": { "tags": [ "solution" @@ -1275,7 +1265,7 @@ }, { "cell_type": "markdown", - "id": "5ed9dc6b", + "id": "44a96f81", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1287,7 +1277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "83220035", + "id": "83523c2f", "metadata": {}, "outputs": [], "source": [ @@ -1300,7 +1290,7 @@ }, { "cell_type": "markdown", - "id": "48b400d0", + "id": "bee1aa55", "metadata": { "tags": [] }, @@ -1315,7 +1305,7 @@ }, { "cell_type": "markdown", - "id": "28a66131", + "id": "dbac0e2e", "metadata": { "tags": [] }, @@ -1326,7 +1316,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ba72cac7", + "id": "7d263995", "metadata": {}, "outputs": [], "source": [ @@ -1340,7 +1330,7 @@ }, { "cell_type": "markdown", - "id": "7919e4d2", + "id": "60081f81", "metadata": { "tags": [] }, @@ -1355,7 +1345,7 @@ }, { "cell_type": "markdown", - "id": "5727914c", + "id": "a037d368", "metadata": { "lines_to_next_cell": 0 }, @@ -1370,15 +1360,16 @@ { "cell_type": "code", "execution_count": null, - "id": "0db3b836", + "id": "308d7f21", "metadata": {}, "outputs": [], "source": [ + "target_class = 0\n", "batch_size = 4\n", "batch = [random_test_mnist[i] for i in range(batch_size)]\n", "x = torch.stack([b[0] for b in batch])\n", "y = torch.tensor([b[1] for b in batch])\n", - "x_fake = torch.tensor(counterfactuals[0, :batch_size])\n", + "x_fake = torch.tensor(counterfactuals[target_class, :batch_size])\n", "x = x.to(device).float()\n", "y = y.to(device)\n", "x_fake = x_fake.to(device).float()\n", @@ -1390,7 +1381,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eda02588", + "id": "12c79b2d", "metadata": { "title": "Another visualization function" }, @@ -1410,7 +1401,7 @@ " ax1.imshow(counterfactual_image)\n", " ax1.set_title(\"Counterfactual\")\n", " ax1.axis(\"off\")\n", - " ax2.imshow(np.abs(attribution))\n", + " ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution)))\n", " ax2.set_title(\"Attribution\")\n", " ax2.axis(\"off\")\n", " plt.show()" @@ -1419,7 +1410,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8d6fc8bf", + "id": "aa24deb7", "metadata": { "lines_to_next_cell": 0 }, @@ -1435,7 +1426,7 @@ }, { "cell_type": "markdown", - "id": "b723e27b", + "id": "1726fcb6", "metadata": { "lines_to_next_cell": 0 }, @@ -1451,7 +1442,122 @@ }, { "cell_type": "markdown", - "id": "66ec380e", + "id": "93aff28e", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "In the lecture, we used the attribution to act as a mask, to gradually go from the original image to the counterfactual image.\n", + "This allowed us to classify all of the intermediate images, and learn how the class changed over the interpolation.\n", + "Here we have a much simpler task so we have some advantages:\n", + "- The counterfactuals are perfect! They already change the bare minimum (trust me).\n", + "- The changes are not objects, but colors.\n", + "As such, we will do a much simpler linear interpolation between the images." + ] + }, + { + "cell_type": "markdown", + "id": "0389b58c", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

      Task 4.2: Interpolation

      \n", + "Let's interpolate between the original image and the counterfactual image.\n", + "We will create 10 images in between the two, and classify them.\n", + "
      " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61b9f9a0", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "num_interpolations = 15\n", + "alpha = np.linspace(0, 1, num_interpolations + 2)[1:-1]\n", + "interpolated_images = [\n", + " alpha[i] * x_fake + (1 - alpha[i]) * x for i in range(num_interpolations)\n", + "]\n", + "interpolated_images = torch.stack(interpolated_images)\n", + "interpolated_classifications = [\n", + " model(interpolated_images[idx].to(device)) for idx in range(num_interpolations)\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e70b512", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# Plot the results\n", + "idx = 0\n", + "fig, axs = plt.subplots(\n", + " batch_size, num_interpolations + 2, figsize=(30, 2 * batch_size)\n", + ")\n", + "for idx in range(batch_size):\n", + " # Plot the original image\n", + " axs[idx, 0].imshow(np.transpose(x[idx].cpu().squeeze().numpy(), (1, 2, 0)))\n", + " axs[idx, 0].axis(\"off\")\n", + " # Use the class as the title\n", + " axs[idx, 0].set_title(f\"Image: y={y[idx].item()}\")\n", + " # Plot the counterfactual image\n", + " axs[idx, -1].imshow(np.transpose(x_fake[idx].cpu().squeeze().numpy(), (1, 2, 0)))\n", + " axs[idx, -1].axis(\"off\")\n", + " # Use the target class as the title\n", + " axs[idx, -1].set_title(f\"CF: y={target_class}\")\n", + " for i, ax in enumerate(axs[idx][1:-1]):\n", + " ax.imshow(\n", + " np.transpose(interpolated_images[i][idx].cpu().squeeze().numpy(), (1, 2, 0))\n", + " )\n", + " ax.axis(\"off\")\n", + " classification = torch.softmax(interpolated_classifications[i][idx], dim=0)\n", + " # Plot the classification as the title in order source classification | target classification\n", + " ax.set_title(\n", + " f\"{classification[y[idx]].item():.2f} | {classification[target_class].item():.2f}\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "e9126f01", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "Take some time to look at the plot we just made.\n", + "On the very left are the images we randomly chose - it's class is shown in the title.\n", + "On the very right are the counterfactual images, all of them made with the same prototype as a style source - the target class is shown in the title.\n", + "In between are the interpolated images - their title shows their classification as \"source classification | target classification\".\n", + "This is a lot to take in, so take your time! Once you're ready, we can move on to the questions." + ] + }, + { + "cell_type": "markdown", + "id": "d50ea18c", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

      Questions

      \n", + "
        \n", + "
      • Do the images change smoothly from one class to another?
      • \n", + "
      • Can you see any patterns in the changes?
      • \n", + "
      • What happens when the original image and the counterfactual image are of the same class?
      • \n", + "
      • Based on this, would you trust this classifier on unseen images more or less than you did before?
      • \n", + "
      " + ] + }, + { + "cell_type": "markdown", + "id": "a93a5437", "metadata": { "lines_to_next_cell": 0 }, @@ -1461,12 +1567,13 @@ "- Created a StarGAN that can change the class of an image\n", "- Evaluated the StarGAN on unseen data\n", "- Used the StarGAN to create counterfactual images\n", - "- Used the counterfactual images to highlight the differences between classes\n" + "- Used the counterfactual images to highlight the differences between classes\n", + "- Interpolated between the images to see how the classifier behaves\n" ] }, { "cell_type": "markdown", - "id": "7592e83d", + "id": "b8f44b99", "metadata": { "lines_to_next_cell": 0 }, @@ -1489,7 +1596,7 @@ }, { "cell_type": "markdown", - "id": "a7400c5c", + "id": "aafe8844", "metadata": {}, "source": [ "

      Task 5.1: Explore the style space

      \n", @@ -1501,7 +1608,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f5493fd4", + "id": "0128d7bd", "metadata": {}, "outputs": [], "source": [ @@ -1536,7 +1643,7 @@ }, { "cell_type": "markdown", - "id": "d57f4253", + "id": "609028bd", "metadata": { "lines_to_next_cell": 0 }, @@ -1552,7 +1659,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4160a32f", + "id": "8b7185c5", "metadata": { "lines_to_next_cell": 0 }, @@ -1579,7 +1686,7 @@ }, { "cell_type": "markdown", - "id": "650c6766", + "id": "772e4231", "metadata": { "lines_to_next_cell": 0 }, @@ -1593,7 +1700,7 @@ }, { "cell_type": "markdown", - "id": "f88d9484", + "id": "7d159c2b", "metadata": { "lines_to_next_cell": 0 }, @@ -1610,7 +1717,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0b57d333", + "id": "13eb454f", "metadata": {}, "outputs": [], "source": [ @@ -1633,7 +1740,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4bb3433e", + "id": "3ddc6a07", "metadata": { "lines_to_next_cell": 0 }, @@ -1642,7 +1749,7 @@ }, { "cell_type": "markdown", - "id": "516e0112", + "id": "303338b1", "metadata": {}, "source": [ "

      Questions

      \n", @@ -1654,7 +1761,7 @@ }, { "cell_type": "markdown", - "id": "63cef16c", + "id": "e18da1e0", "metadata": {}, "source": [ "

      Checkpoint 5

      \n", @@ -1672,14 +1779,21 @@ }, { "cell_type": "markdown", - "id": "e0cb271b", + "id": "b74f495b", "metadata": { - "lines_to_next_cell": 0, - "tags": [ - "solution" - ] + "lines_to_next_cell": 0 }, "source": [ + "# Bonus!\n", + "If you have extra time, you can try to break the StarGAN!\n", + "There are a lot of little things that we did to make sure that it runs correctly - but what if we didn't?\n", + "Some things you might want to try:\n", + "- What happens if you don't use the EMA model?\n", + "- What happens if you change the learning rates?\n", + "- What happens if you add a Sigmoid activation to the output of the style encoder?\n", + "See what else you can think of, and see how finnicky training a GAN can be!\n", + "\n", + "# %% [markdown] tags=[\"solution\"]\n", "The colors for the classes are sampled from matplotlib colormaps! They are the four seasons: spring, summer, autumn, and winter.\n", "Check your style space again to see if you can see the patterns now!" ] @@ -1687,7 +1801,7 @@ { "cell_type": "code", "execution_count": null, - "id": "86681224", + "id": "52f77a00", "metadata": { "tags": [ "solution" diff --git a/solution.py b/solution.py index cd92aa3..e8c7455 100644 --- a/solution.py +++ b/solution.py @@ -222,7 +222,7 @@ def visualize_color_attribution(attribution, original_image): ax1.imshow(original_image) ax1.set_title("Image") ax1.axis("off") - ax2.imshow(np.abs(attribution)) + ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution))) ax2.set_title("Attribution") ax2.axis("off") plt.show() @@ -768,7 +768,6 @@ def copy_parameters(source_model, target_model): ax.axis("off") plt.show() -# %% # %% [markdown] tags=[] #

      Checkpoint 3

      # You've now learned the basics of what makes up a StarGAN, and details on how to perform adversarial training. @@ -814,7 +813,7 @@ def copy_parameters(source_model, target_model): # %% [markdown] # Now we need to use these prototypes to create counterfactual images! # %% [markdown] -#

      Task 4: Create counterfactuals

      +#

      Task 4.1: Create counterfactuals

      # Below, we will store the counterfactual images in the `counterfactuals` array. # #
        @@ -917,11 +916,12 @@ def copy_parameters(source_model, target_model): # Let's try putting the two together to see if we can figure out what exactly makes a class. # # %% +target_class = 0 batch_size = 4 batch = [random_test_mnist[i] for i in range(batch_size)] x = torch.stack([b[0] for b in batch]) y = torch.tensor([b[1] for b in batch]) -x_fake = torch.tensor(counterfactuals[0, :batch_size]) +x_fake = torch.tensor(counterfactuals[target_class, :batch_size]) x = x.to(device).float() y = y.to(device) x_fake = x_fake.to(device).float() @@ -945,7 +945,7 @@ def visualize_color_attribution_and_counterfactual( ax1.imshow(counterfactual_image) ax1.set_title("Counterfactual") ax1.axis("off") - ax2.imshow(np.abs(attribution)) + ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution))) ax2.set_title("Attribution") ax2.axis("off") plt.show() @@ -967,12 +967,76 @@ def visualize_color_attribution_and_counterfactual( #
      #
      # %% [markdown] +# In the lecture, we used the attribution as a mask to gradually go from the original image to the counterfactual image. +# This allowed us to classify all of the intermediate images, and learn how the class changed over the interpolation. +# Here the task is much simpler, which gives us some advantages: +# - The counterfactuals are perfect! They already change the bare minimum (trust me). +# - The changes are not objects, but colors. +# As such, we will do a much simpler linear interpolation between the images. +# %% [markdown] +#

      Task 4.2: Interpolation

      +# Let's interpolate between the original image and the counterfactual image. +# We will create 15 images between the two and classify each of them. +#
      +# %% +num_interpolations = 15 +alpha = np.linspace(0, 1, num_interpolations + 2)[1:-1] +interpolated_images = [ + alpha[i] * x_fake + (1 - alpha[i]) * x for i in range(num_interpolations) +] +interpolated_images = torch.stack(interpolated_images) +interpolated_classifications = [ + model(interpolated_images[idx].to(device)) for idx in range(num_interpolations) +] +# %% +# Plot the results +idx = 0 +fig, axs = plt.subplots( + batch_size, num_interpolations + 2, figsize=(30, 2 * batch_size) +) +for idx in range(batch_size): + # Plot the original image + axs[idx, 0].imshow(np.transpose(x[idx].cpu().squeeze().numpy(), (1, 2, 0))) + axs[idx, 0].axis("off") + # Use the class as the title + axs[idx, 0].set_title(f"Image: y={y[idx].item()}") + # Plot the counterfactual image + axs[idx, -1].imshow(np.transpose(x_fake[idx].cpu().squeeze().numpy(), (1, 2, 0))) + axs[idx, -1].axis("off") + # Use the target class as the title + axs[idx, -1].set_title(f"CF: y={target_class}") + for i, ax in enumerate(axs[idx][1:-1]): + ax.imshow( + np.transpose(interpolated_images[i][idx].cpu().squeeze().numpy(), (1, 2, 0)) + ) + ax.axis("off") + classification = torch.softmax(interpolated_classifications[i][idx], dim=0) + # Show the classification in the title, ordered as: source classification | target classification + ax.set_title( + f"{classification[y[idx]].item():.2f} | {classification[target_class].item():.2f}" + ) +# %% [markdown] +# Take some time to look at the plot we just made. +# On the very left are the images we randomly chose - each one's class is shown in the title. +# On the very right are the counterfactual images, all of them made with the same prototype as a style source - the target class is shown in the title. +# In between are the interpolated images - their title shows their classification as "source classification | target classification". +# This is a lot to take in, so take your time! Once you're ready, we can move on to the questions. +# %% [markdown] +#

      Questions

      +#
        +#
      • Do the images change smoothly from one class to another?
      • +#
      • Can you see any patterns in the changes?
      • +#
      • What happens when the original image and the counterfactual image are of the same class?
      • +#
      • Based on this, would you trust this classifier on unseen images more or less than you did before?
      • +#
      +# %% [markdown] #

      Checkpoint 4

      # At this point you have: # - Created a StarGAN that can change the class of an image # - Evaluated the StarGAN on unseen data # - Used the StarGAN to create counterfactual images # - Used the counterfactual images to highlight the differences between classes +# - Interpolated between the images to see how the classifier behaves # # %% [markdown] # # Part 5: Exploring the Style Space, finding the answer @@ -1100,7 +1164,17 @@ def visualize_color_attribution_and_counterfactual( # If you have any questions, feel free to ask them in the chat! # And check the Solutions exercise for a definite answer to how these classes are defined! -# %% [markdown] tags=["solution"] +# %% [markdown] +# # Bonus! +# If you have extra time, you can try to break the StarGAN! +# There are a lot of little things that we did to make sure that it runs correctly - but what if we didn't? +# Some things you might want to try: +# - What happens if you don't use the EMA model? +# - What happens if you change the learning rates? +# - What happens if you add a Sigmoid activation to the output of the style encoder? +# See what else you can think of, and see how finicky training a GAN can be! + +# %% [markdown] tags=["solution"] +# The colors for the classes are sampled from matplotlib colormaps! They are the four seasons: spring, summer, autumn, and winter. +# Check your style space again to see if you can see the patterns now! # %% tags=["solution"]
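A note on Part 4: the markdown above mentions the lecture's alternative of using the attribution as a mask to move gradually from the original image to the counterfactual, but this patch only implements the plain linear blend. Below is a minimal sketch of that mask-based idea; it is not part of the patch and not the lecture's exact recipe. It assumes `x`, `x_fake`, `model`, and `device` as defined in Part 4 of `solution.py`, plus a per-image numpy `attribution` of shape `(C, H, W)` (e.g. from `integrated_gradients` in Part 2); the quantile schedule and the helper name `mask_interpolate` are illustrative choices, not existing APIs.

import numpy as np
import torch

def mask_interpolate(x0, x1, attribution, steps=10):
    """Blend x1 (counterfactual) into x0 (original), most-attributed pixels first.

    x0, x1: (C, H, W) tensors on the same device; attribution: (C, H, W) numpy array.
    Returns `steps` hybrid images going from (nearly) x0 to x1.
    """
    magnitude = np.abs(attribution).sum(axis=0)  # collapse channels -> (H, W)
    hybrids = []
    for t in np.linspace(0, 1, steps):
        # Copy counterfactual pixels wherever the attribution magnitude falls
        # in the top t fraction; keep the original pixels everywhere else.
        threshold = np.quantile(magnitude, 1 - t)
        mask = torch.from_numpy(magnitude >= threshold).to(x0.device)
        hybrids.append(torch.where(mask, x1, x0))
    return hybrids

# Classify the hybrids, mirroring what was done for the linear interpolation
hybrids = mask_interpolate(x[0], x_fake[0], attribution)
with torch.no_grad():
    logits = torch.stack([model(h.unsqueeze(0).float()).squeeze(0) for h in hybrids])
probs = torch.softmax(logits, dim=1)  # one row of class probabilities per hybrid

Unlike the linear blend, every pixel here is either fully original or fully counterfactual, so the classifier only ever sees realistic color values rather than averaged ones.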