From fd13d06b4062d41527fb7a339fbad68e6201ecfe Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Wed, 21 Aug 2024 08:56:26 -0400 Subject: [PATCH 1/5] Add a section on hybrids --- solution.py | 70 ++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 67 insertions(+), 3 deletions(-) diff --git a/solution.py b/solution.py index cd92aa3..96296a6 100644 --- a/solution.py +++ b/solution.py @@ -768,7 +768,6 @@ def copy_parameters(source_model, target_model): ax.axis("off") plt.show() -# %% # %% [markdown] tags=[] #

Checkpoint 3

# You've now learned the basics of what makes up a StarGAN, and details on how to perform adversarial training. @@ -814,7 +813,7 @@ def copy_parameters(source_model, target_model): # %% [markdown] # Now we need to use these prototypes to create counterfactual images! # %% [markdown] -#

Task 4: Create counterfactuals

+#

Task 4.1: Create counterfactuals

# In the below, we will store the counterfactual images in the `counterfactuals` array. # #
    @@ -917,11 +916,12 @@ def copy_parameters(source_model, target_model): # Let's try putting the two together to see if we can figure out what exactly makes a class. # # %% +target_class = 0 batch_size = 4 batch = [random_test_mnist[i] for i in range(batch_size)] x = torch.stack([b[0] for b in batch]) y = torch.tensor([b[1] for b in batch]) -x_fake = torch.tensor(counterfactuals[0, :batch_size]) +x_fake = torch.tensor(counterfactuals[target_class, :batch_size]) x = x.to(device).float() y = y.to(device) x_fake = x_fake.to(device).float() @@ -967,12 +967,76 @@ def visualize_color_attribution_and_counterfactual( #
#
# %% [markdown] +# In the lecture, we used the attribution to act as a mask, to gradually go from the original image to the counterfactual image. +# This allowed us to classify all of the intermediate images, and learn how the class changed over the interpolation. +# Here we have a much simpler task so we have some advantages: +# - The counterfactuals are perfect! They already change the bare minimum (trust me). +# - The changes are not objects, but colors. +# As such, we will do a much simpler linear interpolation between the images. +# %% [markdown] +#

Task 4.2: Interpolation

+# Let's interpolate between the original image and the counterfactual image. +# We will create 10 images in between the two, and classify them. +#
+# %% +num_interpolations = 15 +alpha = np.linspace(0, 1, num_interpolations + 2)[1:-1] +interpolated_images = [ + alpha[i] * x_fake + (1 - alpha[i]) * x for i in range(num_interpolations) +] +interpolated_images = torch.stack(interpolated_images) +interpolated_classifications = [ + model(interpolated_images[idx].to(device)) for idx in range(num_interpolations) +] +# %% +# Plot the results +idx = 0 +fig, axs = plt.subplots( + batch_size, num_interpolations + 2, figsize=(30, 2 * batch_size) +) +for idx in range(batch_size): + # Plot the original image + axs[idx, 0].imshow(np.transpose(x[idx].cpu().squeeze().numpy(), (1, 2, 0))) + axs[idx, 0].axis("off") + # Use the class as the title + axs[idx, 0].set_title(f"Image: y={y[idx].item()}") + # Plot the counterfactual image + axs[idx, -1].imshow(np.transpose(x_fake[idx].cpu().squeeze().numpy(), (1, 2, 0))) + axs[idx, -1].axis("off") + # Use the target class as the title + axs[idx, -1].set_title(f"CF: y={target_class}") + for i, ax in enumerate(axs[idx][1:-1]): + ax.imshow( + np.transpose(interpolated_images[i][idx].cpu().squeeze().numpy(), (1, 2, 0)) + ) + ax.axis("off") + classification = torch.softmax(interpolated_classifications[i][idx], dim=0) + # Plot the classification as the title in order source classification | target classification + ax.set_title( + f"{classification[y[idx]].item():.2f} | {classification[target_class].item():.2f}" + ) +# %% [markdown] +# Take some time to look at the plot we just made. +# On the very left are the images we randomly chose - it's class is shown in the title. +# On the very right are the counterfactual images, all of them made with the same prototype as a style source - the target class is shown in the title. +# In between are the interpolated images - their title shows their classification as "source classification | target classification". +# This is a lot to take in, so take your time! Once you're ready, we can move on to the questions. +# %% [markdown] +#

Questions

+#
    +#
  • Do the images change smoothly from one class to another?
  • +#
  • Can you see any patterns in the changes?
  • +#
  • What happens when the original image and the counterfactual image are of the same class?
  • +#
  • Based on this, would you trust this classifier on unseen images more or less than you did before?
  • +#
+# %% [markdown] #

Checkpoint 4

# At this point you have: # - Created a StarGAN that can change the class of an image # - Evaluated the StarGAN on unseen data # - Used the StarGAN to create counterfactual images # - Used the counterfactual images to highlight the differences between classes +# - Interpolated between the images to see how the classifier behaves # # %% [markdown] # # Part 5: Exploring the Style Space, finding the answer From dbd2051c7b2038e9012800bbe9067b37cca9e5ae Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Wed, 21 Aug 2024 08:57:42 -0400 Subject: [PATCH 2/5] Add attribution normalization --- solution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/solution.py b/solution.py index 96296a6..4d77b71 100644 --- a/solution.py +++ b/solution.py @@ -222,7 +222,7 @@ def visualize_color_attribution(attribution, original_image): ax1.imshow(original_image) ax1.set_title("Image") ax1.axis("off") - ax2.imshow(np.abs(attribution)) + ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution))) ax2.set_title("Attribution") ax2.axis("off") plt.show() @@ -945,7 +945,7 @@ def visualize_color_attribution_and_counterfactual( ax1.imshow(counterfactual_image) ax1.set_title("Counterfactual") ax1.axis("off") - ax2.imshow(np.abs(attribution)) + ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution))) ax2.set_title("Attribution") ax2.axis("off") plt.show() From ab191d6ce77faf23afe09c758659394377ae9f70 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Wed, 21 Aug 2024 09:00:12 -0400 Subject: [PATCH 3/5] Add bonus task --- solution.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/solution.py b/solution.py index 4d77b71..e8c7455 100644 --- a/solution.py +++ b/solution.py @@ -1164,7 +1164,17 @@ def visualize_color_attribution_and_counterfactual( # If you have any questions, feel free to ask them in the chat! # And check the Solutions exercise for a definite answer to how these classes are defined! -# %% [markdown] tags=["solution"] +# %% [markdown] +# # Bonus! +# If you have extra time, you can try to break the StarGAN! +# There are a lot of little things that we did to make sure that it runs correctly - but what if we didn't? +# Some things you might want to try: +# - What happens if you don't use the EMA model? +# - What happens if you change the learning rates? +# - What happens if you add a Sigmoid activation to the output of the style encoder? +# See what else you can think of, and see how finnicky training a GAN can be! + +## %% [markdown] tags=["solution"] # The colors for the classes are sampled from matplotlib colormaps! They are the four seasons: spring, summer, autumn, and winter. # Check your style space again to see if you can see the patterns now! # %% tags=["solution"] From a440bd603f260a75a963b5b4b266f841fcf79a77 Mon Sep 17 00:00:00 2001 From: adjavon Date: Wed, 21 Aug 2024 18:06:23 +0000 Subject: [PATCH 4/5] Commit from GitHub Actions (Build Notebooks) --- exercise.ipynb | 348 +++++++++++++++++++++++++++++++++---------------- solution.ipynb | 346 +++++++++++++++++++++++++++++++----------------- 2 files changed, 468 insertions(+), 226 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index 92007b0..5fbc67f 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "30c11df5", + "id": "f929d6ee", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "ec2899d4", + "id": "2def42e4", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "2c084b97", + "id": "e386c146", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9d26a8bb", + "id": "ca1ceaeb", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "f8a5937c", + "id": "073875c1", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9c0ce960", + "id": "634ec90e", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "0cb834e5", + "id": "eca78719", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "a32035d7", + "id": "104e3243", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "47684cce", + "id": "7b9b9220", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "6ecddeb8", + "id": "bc54ec3d", "metadata": { "lines_to_next_cell": 0 }, @@ -166,7 +166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c271ecd9", + "id": "5d8324c9", "metadata": {}, "outputs": [], "source": [ @@ -193,7 +193,7 @@ }, { "cell_type": "markdown", - "id": "46a684f4", + "id": "1970f094", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -203,7 +203,7 @@ }, { "cell_type": "markdown", - "id": "0255c073", + "id": "6dae9d0f", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -216,7 +216,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5b162b7", + "id": "cb8da288", "metadata": { "tags": [] }, @@ -234,7 +234,7 @@ }, { "cell_type": "markdown", - "id": "6d418ea1", + "id": "c6ea5e99", "metadata": { "tags": [] }, @@ -250,7 +250,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5ce086ee", + "id": "79098ac5", "metadata": { "tags": [ "task" @@ -271,7 +271,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e4ba6b3a", + "id": "db3ecbf5", "metadata": { "tags": [] }, @@ -284,7 +284,7 @@ }, { "cell_type": "markdown", - "id": "56e432ae", + "id": "a5fc7736", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -296,7 +296,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9561d46f", + "id": "7c214a9a", "metadata": { "tags": [] }, @@ -324,7 +324,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a55fe8ec", + "id": "256ae2c6", "metadata": { "tags": [] }, @@ -337,7 +337,7 @@ }, { "cell_type": "markdown", - "id": "1d8c03a0", + "id": "e477b37b", "metadata": { "lines_to_next_cell": 2 }, @@ -351,7 +351,7 @@ }, { "cell_type": "markdown", - "id": "2a24c70a", + "id": "5660f1bc", "metadata": { "lines_to_next_cell": 0 }, @@ -364,7 +364,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6e875faa", + "id": "438c2f08", "metadata": {}, "outputs": [], "source": [ @@ -376,7 +376,7 @@ " ax1.imshow(original_image)\n", " ax1.set_title(\"Image\")\n", " ax1.axis(\"off\")\n", - " ax2.imshow(np.abs(attribution))\n", + " ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution)))\n", " ax2.set_title(\"Attribution\")\n", " ax2.axis(\"off\")\n", " plt.show()\n", @@ -389,7 +389,7 @@ }, { "cell_type": "markdown", - "id": "3f73608f", + "id": "9049d759", "metadata": { "lines_to_next_cell": 0 }, @@ -403,7 +403,7 @@ }, { "cell_type": "markdown", - "id": "a8e71c0b", + "id": "57f455ec", "metadata": {}, "source": [ "\n", @@ -429,7 +429,7 @@ }, { "cell_type": "markdown", - "id": "dbb04b6f", + "id": "3eb81cd4", "metadata": {}, "source": [ "

Task 2.3: Use random noise as a baseline

\n", @@ -441,7 +441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2fc8f45c", + "id": "5c64ea11", "metadata": { "tags": [ "task" @@ -462,7 +462,7 @@ }, { "cell_type": "markdown", - "id": "bf7e934c", + "id": "5811ca2f", "metadata": { "tags": [] }, @@ -476,7 +476,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2e14f754", + "id": "d0a216f9", "metadata": { "tags": [ "task" @@ -499,7 +499,7 @@ }, { "cell_type": "markdown", - "id": "db46361b", + "id": "dee36774", "metadata": { "tags": [] }, @@ -515,7 +515,7 @@ }, { "cell_type": "markdown", - "id": "e9105812", + "id": "23f7f7d8", "metadata": {}, "source": [ "

BONUS Task: Using different attributions.

\n", @@ -529,7 +529,7 @@ }, { "cell_type": "markdown", - "id": "0b2d0f2f", + "id": "dfdad8b4", "metadata": {}, "source": [ "

Checkpoint 2

\n", @@ -549,7 +549,7 @@ }, { "cell_type": "markdown", - "id": "531169e5", + "id": "d47d79a3", "metadata": { "lines_to_next_cell": 0 }, @@ -577,7 +577,7 @@ }, { "cell_type": "markdown", - "id": "331e56d6", + "id": "f4e5602e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -600,7 +600,7 @@ { "cell_type": "code", "execution_count": null, - "id": "301ee289", + "id": "0ee0fb95", "metadata": {}, "outputs": [], "source": [ @@ -632,7 +632,7 @@ }, { "cell_type": "markdown", - "id": "4ce023f6", + "id": "4794a5fc", "metadata": { "lines_to_next_cell": 0 }, @@ -647,7 +647,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c2698719", + "id": "a111d5e8", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -668,7 +668,7 @@ }, { "cell_type": "markdown", - "id": "16f87104", + "id": "f1716b47", "metadata": { "tags": [] }, @@ -683,7 +683,7 @@ }, { "cell_type": "markdown", - "id": "9f1d1149", + "id": "5d54535a", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -700,7 +700,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14e0c929", + "id": "39f5db85", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -714,7 +714,7 @@ }, { "cell_type": "markdown", - "id": "231a5202", + "id": "1db62898", "metadata": { "lines_to_next_cell": 0 }, @@ -725,7 +725,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c0a2d54d", + "id": "c22a89e4", "metadata": {}, "outputs": [], "source": [ @@ -735,7 +735,7 @@ }, { "cell_type": "markdown", - "id": "4540ef18", + "id": "2fd71dd0", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -753,7 +753,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b9fc6671", + "id": "b7c89896", "metadata": { "lines_to_next_cell": 0 }, @@ -765,7 +765,7 @@ }, { "cell_type": "markdown", - "id": "196daf45", + "id": "38280c63", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -784,7 +784,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1e9ddd12", + "id": "52034753", "metadata": {}, "outputs": [], "source": [ @@ -793,7 +793,7 @@ }, { "cell_type": "markdown", - "id": "eade7df1", + "id": "501b7a7e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -809,7 +809,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1deb8b8b", + "id": "c26dc615", "metadata": {}, "outputs": [], "source": [ @@ -818,7 +818,7 @@ }, { "cell_type": "markdown", - "id": "ba4a7f7f", + "id": "85a5ffc4", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -830,7 +830,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b5b3d5dc", + "id": "8822bdf7", "metadata": {}, "outputs": [], "source": [ @@ -843,7 +843,7 @@ }, { "cell_type": "markdown", - "id": "a029e923", + "id": "49dff835", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -857,7 +857,7 @@ { "cell_type": "code", "execution_count": null, - "id": "54b4de87", + "id": "bef2d8e0", "metadata": {}, "outputs": [], "source": [ @@ -869,7 +869,7 @@ }, { "cell_type": "markdown", - "id": "014e484e", + "id": "6cd15802", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -889,7 +889,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f6344c83", + "id": "364c8d90", "metadata": {}, "outputs": [], "source": [ @@ -913,7 +913,7 @@ { "cell_type": "code", "execution_count": null, - "id": "08b7b3af", + "id": "b0c26036", "metadata": {}, "outputs": [], "source": [ @@ -923,7 +923,7 @@ }, { "cell_type": "markdown", - "id": "23fbf680", + "id": "80c55b83", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -945,7 +945,7 @@ }, { "cell_type": "markdown", - "id": "9cb8281d", + "id": "d6935683", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -957,7 +957,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3b01306d", + "id": "5a11de5e", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1068,7 +1068,7 @@ }, { "cell_type": "markdown", - "id": "4c25819b", + "id": "580be770", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1080,7 +1080,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0d64d32d", + "id": "ff582fa2", "metadata": {}, "outputs": [], "source": [ @@ -1096,7 +1096,7 @@ }, { "cell_type": "markdown", - "id": "326ba2b5", + "id": "6f66943c", "metadata": { "tags": [] }, @@ -1111,7 +1111,7 @@ }, { "cell_type": "markdown", - "id": "3e58ca01", + "id": "b75e3ea6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1123,7 +1123,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1c522efa", + "id": "e4661010", "metadata": {}, "outputs": [], "source": [ @@ -1143,19 +1143,9 @@ "plt.show()" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "30b6dac9", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", - "id": "a3ecbc7b", + "id": "bd665cb6", "metadata": { "tags": [] }, @@ -1171,7 +1161,7 @@ }, { "cell_type": "markdown", - "id": "e6bdaecb", + "id": "657df1b2", "metadata": { "tags": [] }, @@ -1181,7 +1171,7 @@ }, { "cell_type": "markdown", - "id": "7f994579", + "id": "62ee3d1f", "metadata": { "tags": [] }, @@ -1198,7 +1188,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4e4fe83e", + "id": "340f5689", "metadata": { "title": "Loading the test dataset" }, @@ -1218,7 +1208,7 @@ }, { "cell_type": "markdown", - "id": "049a6b22", + "id": "fac75d37", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1230,7 +1220,7 @@ { "cell_type": "code", "execution_count": null, - "id": "639f37e2", + "id": "9a901d76", "metadata": {}, "outputs": [], "source": [ @@ -1243,7 +1233,7 @@ }, { "cell_type": "markdown", - "id": "02cb705b", + "id": "194b5756", "metadata": { "lines_to_next_cell": 0 }, @@ -1253,12 +1243,12 @@ }, { "cell_type": "markdown", - "id": "f41a6ce5", + "id": "c6b2d5d5", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

Task 4: Create counterfactuals

\n", + "

Task 4.1: Create counterfactuals

\n", "In the below, we will store the counterfactual images in the `counterfactuals` array.\n", "\n", "
    \n", @@ -1271,7 +1261,7 @@ { "cell_type": "code", "execution_count": null, - "id": "282f8858", + "id": "bbe08ad9", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1307,7 +1297,7 @@ }, { "cell_type": "markdown", - "id": "ebffc15f", + "id": "4c1cc8e2", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1319,7 +1309,7 @@ { "cell_type": "code", "execution_count": null, - "id": "baac8071", + "id": "d2f3e241", "metadata": {}, "outputs": [], "source": [ @@ -1332,7 +1322,7 @@ }, { "cell_type": "markdown", - "id": "88e7ea0c", + "id": "f085ddbd", "metadata": { "tags": [] }, @@ -1347,7 +1337,7 @@ }, { "cell_type": "markdown", - "id": "25972c49", + "id": "3713cc97", "metadata": { "tags": [] }, @@ -1358,7 +1348,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12d49576", + "id": "d4d09ab3", "metadata": {}, "outputs": [], "source": [ @@ -1372,7 +1362,7 @@ }, { "cell_type": "markdown", - "id": "8e6f04f3", + "id": "d8071f41", "metadata": { "tags": [] }, @@ -1387,7 +1377,7 @@ }, { "cell_type": "markdown", - "id": "50728ff2", + "id": "e8a1bdd8", "metadata": { "lines_to_next_cell": 0 }, @@ -1402,15 +1392,16 @@ { "cell_type": "code", "execution_count": null, - "id": "dedc0f83", + "id": "27a5faa2", "metadata": {}, "outputs": [], "source": [ + "target_class = 0\n", "batch_size = 4\n", "batch = [random_test_mnist[i] for i in range(batch_size)]\n", "x = torch.stack([b[0] for b in batch])\n", "y = torch.tensor([b[1] for b in batch])\n", - "x_fake = torch.tensor(counterfactuals[0, :batch_size])\n", + "x_fake = torch.tensor(counterfactuals[target_class, :batch_size])\n", "x = x.to(device).float()\n", "y = y.to(device)\n", "x_fake = x_fake.to(device).float()\n", @@ -1422,7 +1413,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5446e796", + "id": "fe913d7c", "metadata": { "title": "Another visualization function" }, @@ -1442,7 +1433,7 @@ " ax1.imshow(counterfactual_image)\n", " ax1.set_title(\"Counterfactual\")\n", " ax1.axis(\"off\")\n", - " ax2.imshow(np.abs(attribution))\n", + " ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution)))\n", " ax2.set_title(\"Attribution\")\n", " ax2.axis(\"off\")\n", " plt.show()" @@ -1451,7 +1442,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5e2fb59e", + "id": "088a3080", "metadata": { "lines_to_next_cell": 0 }, @@ -1467,7 +1458,7 @@ }, { "cell_type": "markdown", - "id": "b393a8f1", + "id": "4c183323", "metadata": { "lines_to_next_cell": 0 }, @@ -1483,7 +1474,122 @@ }, { "cell_type": "markdown", - "id": "5ba47fc6", + "id": "6e010c04", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "In the lecture, we used the attribution to act as a mask, to gradually go from the original image to the counterfactual image.\n", + "This allowed us to classify all of the intermediate images, and learn how the class changed over the interpolation.\n", + "Here we have a much simpler task so we have some advantages:\n", + "- The counterfactuals are perfect! They already change the bare minimum (trust me).\n", + "- The changes are not objects, but colors.\n", + "As such, we will do a much simpler linear interpolation between the images." + ] + }, + { + "cell_type": "markdown", + "id": "f83ddb66", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

    Task 4.2: Interpolation

    \n", + "Let's interpolate between the original image and the counterfactual image.\n", + "We will create 10 images in between the two, and classify them.\n", + "
    " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a5b7362", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "num_interpolations = 15\n", + "alpha = np.linspace(0, 1, num_interpolations + 2)[1:-1]\n", + "interpolated_images = [\n", + " alpha[i] * x_fake + (1 - alpha[i]) * x for i in range(num_interpolations)\n", + "]\n", + "interpolated_images = torch.stack(interpolated_images)\n", + "interpolated_classifications = [\n", + " model(interpolated_images[idx].to(device)) for idx in range(num_interpolations)\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cdd2182", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# Plot the results\n", + "idx = 0\n", + "fig, axs = plt.subplots(\n", + " batch_size, num_interpolations + 2, figsize=(30, 2 * batch_size)\n", + ")\n", + "for idx in range(batch_size):\n", + " # Plot the original image\n", + " axs[idx, 0].imshow(np.transpose(x[idx].cpu().squeeze().numpy(), (1, 2, 0)))\n", + " axs[idx, 0].axis(\"off\")\n", + " # Use the class as the title\n", + " axs[idx, 0].set_title(f\"Image: y={y[idx].item()}\")\n", + " # Plot the counterfactual image\n", + " axs[idx, -1].imshow(np.transpose(x_fake[idx].cpu().squeeze().numpy(), (1, 2, 0)))\n", + " axs[idx, -1].axis(\"off\")\n", + " # Use the target class as the title\n", + " axs[idx, -1].set_title(f\"CF: y={target_class}\")\n", + " for i, ax in enumerate(axs[idx][1:-1]):\n", + " ax.imshow(\n", + " np.transpose(interpolated_images[i][idx].cpu().squeeze().numpy(), (1, 2, 0))\n", + " )\n", + " ax.axis(\"off\")\n", + " classification = torch.softmax(interpolated_classifications[i][idx], dim=0)\n", + " # Plot the classification as the title in order source classification | target classification\n", + " ax.set_title(\n", + " f\"{classification[y[idx]].item():.2f} | {classification[target_class].item():.2f}\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "e3b495ef", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "Take some time to look at the plot we just made.\n", + "On the very left are the images we randomly chose - it's class is shown in the title.\n", + "On the very right are the counterfactual images, all of them made with the same prototype as a style source - the target class is shown in the title.\n", + "In between are the interpolated images - their title shows their classification as \"source classification | target classification\".\n", + "This is a lot to take in, so take your time! Once you're ready, we can move on to the questions." + ] + }, + { + "cell_type": "markdown", + "id": "d12b0491", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

    Questions

    \n", + "
      \n", + "
    • Do the images change smoothly from one class to another?
    • \n", + "
    • Can you see any patterns in the changes?
    • \n", + "
    • What happens when the original image and the counterfactual image are of the same class?
    • \n", + "
    • Based on this, would you trust this classifier on unseen images more or less than you did before?
    • \n", + "
    " + ] + }, + { + "cell_type": "markdown", + "id": "4d04321c", "metadata": { "lines_to_next_cell": 0 }, @@ -1493,12 +1599,13 @@ "- Created a StarGAN that can change the class of an image\n", "- Evaluated the StarGAN on unseen data\n", "- Used the StarGAN to create counterfactual images\n", - "- Used the counterfactual images to highlight the differences between classes\n" + "- Used the counterfactual images to highlight the differences between classes\n", + "- Interpolated between the images to see how the classifier behaves\n" ] }, { "cell_type": "markdown", - "id": "2654d788", + "id": "74cf3a6d", "metadata": { "lines_to_next_cell": 0 }, @@ -1521,7 +1628,7 @@ }, { "cell_type": "markdown", - "id": "76559366", + "id": "3a2e2d67", "metadata": {}, "source": [ "

    Task 5.1: Explore the style space

    \n", @@ -1533,7 +1640,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f1fdb890", + "id": "0ce55391", "metadata": {}, "outputs": [], "source": [ @@ -1568,7 +1675,7 @@ }, { "cell_type": "markdown", - "id": "b666769e", + "id": "93beff08", "metadata": { "lines_to_next_cell": 0 }, @@ -1584,7 +1691,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e61d0c9b", + "id": "0f1e9fae", "metadata": { "lines_to_next_cell": 0 }, @@ -1611,7 +1718,7 @@ }, { "cell_type": "markdown", - "id": "6f1d3ff3", + "id": "f2a20f61", "metadata": { "lines_to_next_cell": 0 }, @@ -1625,7 +1732,7 @@ }, { "cell_type": "markdown", - "id": "90889399", + "id": "ec1bb346", "metadata": { "lines_to_next_cell": 0 }, @@ -1642,7 +1749,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f67b3f90", + "id": "92da5e43", "metadata": {}, "outputs": [], "source": [ @@ -1665,7 +1772,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b18b2b81", + "id": "017f3089", "metadata": { "lines_to_next_cell": 0 }, @@ -1674,7 +1781,7 @@ }, { "cell_type": "markdown", - "id": "bf87e80b", + "id": "e4153d66", "metadata": {}, "source": [ "

    Questions

    \n", @@ -1686,7 +1793,7 @@ }, { "cell_type": "markdown", - "id": "11aafcc5", + "id": "f9218632", "metadata": {}, "source": [ "

    Checkpoint 5

    \n", @@ -1701,6 +1808,27 @@ "If you have any questions, feel free to ask them in the chat!\n", "And check the Solutions exercise for a definite answer to how these classes are defined!" ] + }, + { + "cell_type": "markdown", + "id": "7413f5e2", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "# Bonus!\n", + "If you have extra time, you can try to break the StarGAN!\n", + "There are a lot of little things that we did to make sure that it runs correctly - but what if we didn't?\n", + "Some things you might want to try:\n", + "- What happens if you don't use the EMA model?\n", + "- What happens if you change the learning rates?\n", + "- What happens if you add a Sigmoid activation to the output of the style encoder?\n", + "See what else you can think of, and see how finnicky training a GAN can be!\n", + "\n", + "# %% [markdown] tags=[\"solution\"]\n", + "The colors for the classes are sampled from matplotlib colormaps! They are the four seasons: spring, summer, autumn, and winter.\n", + "Check your style space again to see if you can see the patterns now!" + ] } ], "metadata": { diff --git a/solution.ipynb b/solution.ipynb index b0b9e5a..ddf44e5 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "30c11df5", + "id": "f929d6ee", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "ec2899d4", + "id": "2def42e4", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "2c084b97", + "id": "e386c146", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9d26a8bb", + "id": "ca1ceaeb", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "f8a5937c", + "id": "073875c1", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9c0ce960", + "id": "634ec90e", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "0cb834e5", + "id": "eca78719", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "a32035d7", + "id": "104e3243", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0146821b", + "id": "8df1ad50", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "6ecddeb8", + "id": "bc54ec3d", "metadata": { "lines_to_next_cell": 0 }, @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c271ecd9", + "id": "5d8324c9", "metadata": {}, "outputs": [], "source": [ @@ -192,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "46a684f4", + "id": "1970f094", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "0255c073", + "id": "6dae9d0f", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5b162b7", + "id": "cb8da288", "metadata": { "tags": [] }, @@ -233,7 +233,7 @@ }, { "cell_type": "markdown", - "id": "6d418ea1", + "id": "c6ea5e99", "metadata": { "tags": [] }, @@ -249,7 +249,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f93e8067", + "id": "029446ea", "metadata": { "tags": [ "solution" @@ -273,7 +273,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e4ba6b3a", + "id": "db3ecbf5", "metadata": { "tags": [] }, @@ -286,7 +286,7 @@ }, { "cell_type": "markdown", - "id": "56e432ae", + "id": "a5fc7736", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -298,7 +298,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9561d46f", + "id": "7c214a9a", "metadata": { "tags": [] }, @@ -326,7 +326,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a55fe8ec", + "id": "256ae2c6", "metadata": { "tags": [] }, @@ -339,7 +339,7 @@ }, { "cell_type": "markdown", - "id": "1d8c03a0", + "id": "e477b37b", "metadata": { "lines_to_next_cell": 2 }, @@ -353,7 +353,7 @@ }, { "cell_type": "markdown", - "id": "2a24c70a", + "id": "5660f1bc", "metadata": { "lines_to_next_cell": 0 }, @@ -366,7 +366,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6e875faa", + "id": "438c2f08", "metadata": {}, "outputs": [], "source": [ @@ -378,7 +378,7 @@ " ax1.imshow(original_image)\n", " ax1.set_title(\"Image\")\n", " ax1.axis(\"off\")\n", - " ax2.imshow(np.abs(attribution))\n", + " ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution)))\n", " ax2.set_title(\"Attribution\")\n", " ax2.axis(\"off\")\n", " plt.show()\n", @@ -391,7 +391,7 @@ }, { "cell_type": "markdown", - "id": "3f73608f", + "id": "9049d759", "metadata": { "lines_to_next_cell": 0 }, @@ -405,7 +405,7 @@ }, { "cell_type": "markdown", - "id": "a8e71c0b", + "id": "57f455ec", "metadata": {}, "source": [ "\n", @@ -431,7 +431,7 @@ }, { "cell_type": "markdown", - "id": "dbb04b6f", + "id": "3eb81cd4", "metadata": {}, "source": [ "

    Task 2.3: Use random noise as a baseline

    \n", @@ -443,7 +443,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cde2c2ff", + "id": "1f1843ae", "metadata": { "tags": [ "solution" @@ -469,7 +469,7 @@ }, { "cell_type": "markdown", - "id": "bf7e934c", + "id": "5811ca2f", "metadata": { "tags": [] }, @@ -483,7 +483,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a0cb195e", + "id": "45789751", "metadata": { "tags": [ "solution" @@ -511,7 +511,7 @@ }, { "cell_type": "markdown", - "id": "db46361b", + "id": "dee36774", "metadata": { "tags": [] }, @@ -527,7 +527,7 @@ }, { "cell_type": "markdown", - "id": "e9105812", + "id": "23f7f7d8", "metadata": {}, "source": [ "

    BONUS Task: Using different attributions.

    \n", @@ -541,7 +541,7 @@ }, { "cell_type": "markdown", - "id": "0b2d0f2f", + "id": "dfdad8b4", "metadata": {}, "source": [ "

    Checkpoint 2

    \n", @@ -561,7 +561,7 @@ }, { "cell_type": "markdown", - "id": "531169e5", + "id": "d47d79a3", "metadata": { "lines_to_next_cell": 0 }, @@ -589,7 +589,7 @@ }, { "cell_type": "markdown", - "id": "331e56d6", + "id": "f4e5602e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -612,7 +612,7 @@ { "cell_type": "code", "execution_count": null, - "id": "301ee289", + "id": "0ee0fb95", "metadata": {}, "outputs": [], "source": [ @@ -644,7 +644,7 @@ }, { "cell_type": "markdown", - "id": "4ce023f6", + "id": "4794a5fc", "metadata": { "lines_to_next_cell": 0 }, @@ -659,7 +659,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b491022a", + "id": "ab9dd94e", "metadata": { "tags": [ "solution" @@ -676,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "16f87104", + "id": "f1716b47", "metadata": { "tags": [] }, @@ -691,7 +691,7 @@ }, { "cell_type": "markdown", - "id": "9f1d1149", + "id": "5d54535a", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -708,7 +708,7 @@ { "cell_type": "code", "execution_count": null, - "id": "71695d57", + "id": "069675ed", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -722,7 +722,7 @@ }, { "cell_type": "markdown", - "id": "231a5202", + "id": "1db62898", "metadata": { "lines_to_next_cell": 0 }, @@ -733,7 +733,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c0a2d54d", + "id": "c22a89e4", "metadata": {}, "outputs": [], "source": [ @@ -743,7 +743,7 @@ }, { "cell_type": "markdown", - "id": "4540ef18", + "id": "2fd71dd0", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -761,7 +761,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b9fc6671", + "id": "b7c89896", "metadata": { "lines_to_next_cell": 0 }, @@ -773,7 +773,7 @@ }, { "cell_type": "markdown", - "id": "196daf45", + "id": "38280c63", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -792,7 +792,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1e9ddd12", + "id": "52034753", "metadata": {}, "outputs": [], "source": [ @@ -801,7 +801,7 @@ }, { "cell_type": "markdown", - "id": "eade7df1", + "id": "501b7a7e", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -817,7 +817,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1deb8b8b", + "id": "c26dc615", "metadata": {}, "outputs": [], "source": [ @@ -826,7 +826,7 @@ }, { "cell_type": "markdown", - "id": "ba4a7f7f", + "id": "85a5ffc4", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -838,7 +838,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b5b3d5dc", + "id": "8822bdf7", "metadata": {}, "outputs": [], "source": [ @@ -851,7 +851,7 @@ }, { "cell_type": "markdown", - "id": "a029e923", + "id": "49dff835", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -865,7 +865,7 @@ { "cell_type": "code", "execution_count": null, - "id": "54b4de87", + "id": "bef2d8e0", "metadata": {}, "outputs": [], "source": [ @@ -877,7 +877,7 @@ }, { "cell_type": "markdown", - "id": "014e484e", + "id": "6cd15802", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -897,7 +897,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f6344c83", + "id": "364c8d90", "metadata": {}, "outputs": [], "source": [ @@ -921,7 +921,7 @@ { "cell_type": "code", "execution_count": null, - "id": "08b7b3af", + "id": "b0c26036", "metadata": {}, "outputs": [], "source": [ @@ -931,7 +931,7 @@ }, { "cell_type": "markdown", - "id": "23fbf680", + "id": "80c55b83", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -953,7 +953,7 @@ }, { "cell_type": "markdown", - "id": "9cb8281d", + "id": "d6935683", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -965,7 +965,7 @@ { "cell_type": "code", "execution_count": null, - "id": "699b3220", + "id": "001e8548", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1035,7 +1035,7 @@ }, { "cell_type": "markdown", - "id": "4c25819b", + "id": "580be770", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1047,7 +1047,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0d64d32d", + "id": "ff582fa2", "metadata": {}, "outputs": [], "source": [ @@ -1063,7 +1063,7 @@ }, { "cell_type": "markdown", - "id": "326ba2b5", + "id": "6f66943c", "metadata": { "tags": [] }, @@ -1078,7 +1078,7 @@ }, { "cell_type": "markdown", - "id": "3e58ca01", + "id": "b75e3ea6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1090,7 +1090,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1c522efa", + "id": "e4661010", "metadata": {}, "outputs": [], "source": [ @@ -1110,19 +1110,9 @@ "plt.show()" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "30b6dac9", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", - "id": "a3ecbc7b", + "id": "bd665cb6", "metadata": { "tags": [] }, @@ -1138,7 +1128,7 @@ }, { "cell_type": "markdown", - "id": "e6bdaecb", + "id": "657df1b2", "metadata": { "tags": [] }, @@ -1148,7 +1138,7 @@ }, { "cell_type": "markdown", - "id": "7f994579", + "id": "62ee3d1f", "metadata": { "tags": [] }, @@ -1165,7 +1155,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4e4fe83e", + "id": "340f5689", "metadata": { "title": "Loading the test dataset" }, @@ -1185,7 +1175,7 @@ }, { "cell_type": "markdown", - "id": "049a6b22", + "id": "fac75d37", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1197,7 +1187,7 @@ { "cell_type": "code", "execution_count": null, - "id": "639f37e2", + "id": "9a901d76", "metadata": {}, "outputs": [], "source": [ @@ -1210,7 +1200,7 @@ }, { "cell_type": "markdown", - "id": "02cb705b", + "id": "194b5756", "metadata": { "lines_to_next_cell": 0 }, @@ -1220,12 +1210,12 @@ }, { "cell_type": "markdown", - "id": "f41a6ce5", + "id": "c6b2d5d5", "metadata": { "lines_to_next_cell": 0 }, "source": [ - "

    Task 4: Create counterfactuals

    \n", + "

    Task 4.1: Create counterfactuals

    \n", "In the below, we will store the counterfactual images in the `counterfactuals` array.\n", "\n", "
      \n", @@ -1238,7 +1228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "00616e67", + "id": "37010964", "metadata": { "tags": [ "solution" @@ -1275,7 +1265,7 @@ }, { "cell_type": "markdown", - "id": "ebffc15f", + "id": "4c1cc8e2", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1287,7 +1277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "baac8071", + "id": "d2f3e241", "metadata": {}, "outputs": [], "source": [ @@ -1300,7 +1290,7 @@ }, { "cell_type": "markdown", - "id": "88e7ea0c", + "id": "f085ddbd", "metadata": { "tags": [] }, @@ -1315,7 +1305,7 @@ }, { "cell_type": "markdown", - "id": "25972c49", + "id": "3713cc97", "metadata": { "tags": [] }, @@ -1326,7 +1316,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12d49576", + "id": "d4d09ab3", "metadata": {}, "outputs": [], "source": [ @@ -1340,7 +1330,7 @@ }, { "cell_type": "markdown", - "id": "8e6f04f3", + "id": "d8071f41", "metadata": { "tags": [] }, @@ -1355,7 +1345,7 @@ }, { "cell_type": "markdown", - "id": "50728ff2", + "id": "e8a1bdd8", "metadata": { "lines_to_next_cell": 0 }, @@ -1370,15 +1360,16 @@ { "cell_type": "code", "execution_count": null, - "id": "dedc0f83", + "id": "27a5faa2", "metadata": {}, "outputs": [], "source": [ + "target_class = 0\n", "batch_size = 4\n", "batch = [random_test_mnist[i] for i in range(batch_size)]\n", "x = torch.stack([b[0] for b in batch])\n", "y = torch.tensor([b[1] for b in batch])\n", - "x_fake = torch.tensor(counterfactuals[0, :batch_size])\n", + "x_fake = torch.tensor(counterfactuals[target_class, :batch_size])\n", "x = x.to(device).float()\n", "y = y.to(device)\n", "x_fake = x_fake.to(device).float()\n", @@ -1390,7 +1381,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5446e796", + "id": "fe913d7c", "metadata": { "title": "Another visualization function" }, @@ -1410,7 +1401,7 @@ " ax1.imshow(counterfactual_image)\n", " ax1.set_title(\"Counterfactual\")\n", " ax1.axis(\"off\")\n", - " ax2.imshow(np.abs(attribution))\n", + " ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution)))\n", " ax2.set_title(\"Attribution\")\n", " ax2.axis(\"off\")\n", " plt.show()" @@ -1419,7 +1410,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5e2fb59e", + "id": "088a3080", "metadata": { "lines_to_next_cell": 0 }, @@ -1435,7 +1426,7 @@ }, { "cell_type": "markdown", - "id": "b393a8f1", + "id": "4c183323", "metadata": { "lines_to_next_cell": 0 }, @@ -1451,7 +1442,122 @@ }, { "cell_type": "markdown", - "id": "5ba47fc6", + "id": "6e010c04", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "In the lecture, we used the attribution to act as a mask, to gradually go from the original image to the counterfactual image.\n", + "This allowed us to classify all of the intermediate images, and learn how the class changed over the interpolation.\n", + "Here we have a much simpler task so we have some advantages:\n", + "- The counterfactuals are perfect! They already change the bare minimum (trust me).\n", + "- The changes are not objects, but colors.\n", + "As such, we will do a much simpler linear interpolation between the images." + ] + }, + { + "cell_type": "markdown", + "id": "f83ddb66", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

      Task 4.2: Interpolation

      \n", + "Let's interpolate between the original image and the counterfactual image.\n", + "We will create 10 images in between the two, and classify them.\n", + "
      " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a5b7362", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "num_interpolations = 15\n", + "alpha = np.linspace(0, 1, num_interpolations + 2)[1:-1]\n", + "interpolated_images = [\n", + " alpha[i] * x_fake + (1 - alpha[i]) * x for i in range(num_interpolations)\n", + "]\n", + "interpolated_images = torch.stack(interpolated_images)\n", + "interpolated_classifications = [\n", + " model(interpolated_images[idx].to(device)) for idx in range(num_interpolations)\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cdd2182", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# Plot the results\n", + "idx = 0\n", + "fig, axs = plt.subplots(\n", + " batch_size, num_interpolations + 2, figsize=(30, 2 * batch_size)\n", + ")\n", + "for idx in range(batch_size):\n", + " # Plot the original image\n", + " axs[idx, 0].imshow(np.transpose(x[idx].cpu().squeeze().numpy(), (1, 2, 0)))\n", + " axs[idx, 0].axis(\"off\")\n", + " # Use the class as the title\n", + " axs[idx, 0].set_title(f\"Image: y={y[idx].item()}\")\n", + " # Plot the counterfactual image\n", + " axs[idx, -1].imshow(np.transpose(x_fake[idx].cpu().squeeze().numpy(), (1, 2, 0)))\n", + " axs[idx, -1].axis(\"off\")\n", + " # Use the target class as the title\n", + " axs[idx, -1].set_title(f\"CF: y={target_class}\")\n", + " for i, ax in enumerate(axs[idx][1:-1]):\n", + " ax.imshow(\n", + " np.transpose(interpolated_images[i][idx].cpu().squeeze().numpy(), (1, 2, 0))\n", + " )\n", + " ax.axis(\"off\")\n", + " classification = torch.softmax(interpolated_classifications[i][idx], dim=0)\n", + " # Plot the classification as the title in order source classification | target classification\n", + " ax.set_title(\n", + " f\"{classification[y[idx]].item():.2f} | {classification[target_class].item():.2f}\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "e3b495ef", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "Take some time to look at the plot we just made.\n", + "On the very left are the images we randomly chose - it's class is shown in the title.\n", + "On the very right are the counterfactual images, all of them made with the same prototype as a style source - the target class is shown in the title.\n", + "In between are the interpolated images - their title shows their classification as \"source classification | target classification\".\n", + "This is a lot to take in, so take your time! Once you're ready, we can move on to the questions." + ] + }, + { + "cell_type": "markdown", + "id": "d12b0491", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "

      Questions

      \n", + "
        \n", + "
      • Do the images change smoothly from one class to another?
      • \n", + "
      • Can you see any patterns in the changes?
      • \n", + "
      • What happens when the original image and the counterfactual image are of the same class?
      • \n", + "
      • Based on this, would you trust this classifier on unseen images more or less than you did before?
      • \n", + "
      " + ] + }, + { + "cell_type": "markdown", + "id": "4d04321c", "metadata": { "lines_to_next_cell": 0 }, @@ -1461,12 +1567,13 @@ "- Created a StarGAN that can change the class of an image\n", "- Evaluated the StarGAN on unseen data\n", "- Used the StarGAN to create counterfactual images\n", - "- Used the counterfactual images to highlight the differences between classes\n" + "- Used the counterfactual images to highlight the differences between classes\n", + "- Interpolated between the images to see how the classifier behaves\n" ] }, { "cell_type": "markdown", - "id": "2654d788", + "id": "74cf3a6d", "metadata": { "lines_to_next_cell": 0 }, @@ -1489,7 +1596,7 @@ }, { "cell_type": "markdown", - "id": "76559366", + "id": "3a2e2d67", "metadata": {}, "source": [ "

      Task 5.1: Explore the style space

      \n", @@ -1501,7 +1608,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f1fdb890", + "id": "0ce55391", "metadata": {}, "outputs": [], "source": [ @@ -1536,7 +1643,7 @@ }, { "cell_type": "markdown", - "id": "b666769e", + "id": "93beff08", "metadata": { "lines_to_next_cell": 0 }, @@ -1552,7 +1659,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e61d0c9b", + "id": "0f1e9fae", "metadata": { "lines_to_next_cell": 0 }, @@ -1579,7 +1686,7 @@ }, { "cell_type": "markdown", - "id": "6f1d3ff3", + "id": "f2a20f61", "metadata": { "lines_to_next_cell": 0 }, @@ -1593,7 +1700,7 @@ }, { "cell_type": "markdown", - "id": "90889399", + "id": "ec1bb346", "metadata": { "lines_to_next_cell": 0 }, @@ -1610,7 +1717,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f67b3f90", + "id": "92da5e43", "metadata": {}, "outputs": [], "source": [ @@ -1633,7 +1740,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b18b2b81", + "id": "017f3089", "metadata": { "lines_to_next_cell": 0 }, @@ -1642,7 +1749,7 @@ }, { "cell_type": "markdown", - "id": "bf87e80b", + "id": "e4153d66", "metadata": {}, "source": [ "

      Questions

      \n", @@ -1654,7 +1761,7 @@ }, { "cell_type": "markdown", - "id": "11aafcc5", + "id": "f9218632", "metadata": {}, "source": [ "

      Checkpoint 5

      \n", @@ -1672,14 +1779,21 @@ }, { "cell_type": "markdown", - "id": "a5c8b45e", + "id": "7413f5e2", "metadata": { - "lines_to_next_cell": 0, - "tags": [ - "solution" - ] + "lines_to_next_cell": 0 }, "source": [ + "# Bonus!\n", + "If you have extra time, you can try to break the StarGAN!\n", + "There are a lot of little things that we did to make sure that it runs correctly - but what if we didn't?\n", + "Some things you might want to try:\n", + "- What happens if you don't use the EMA model?\n", + "- What happens if you change the learning rates?\n", + "- What happens if you add a Sigmoid activation to the output of the style encoder?\n", + "See what else you can think of, and see how finnicky training a GAN can be!\n", + "\n", + "# %% [markdown] tags=[\"solution\"]\n", "The colors for the classes are sampled from matplotlib colormaps! They are the four seasons: spring, summer, autumn, and winter.\n", "Check your style space again to see if you can see the patterns now!" ] @@ -1687,7 +1801,7 @@ { "cell_type": "code", "execution_count": null, - "id": "45e17541", + "id": "97bc18ec", "metadata": { "tags": [ "solution" From 99c6708473aa9d36df518f1efe787e54f3386a27 Mon Sep 17 00:00:00 2001 From: adjavon Date: Wed, 21 Aug 2024 18:48:26 +0000 Subject: [PATCH 5/5] Commit from GitHub Actions (Build Notebooks) --- exercise.ipynb | 204 ++++++++++++++++++++++++------------------------ solution.ipynb | 206 ++++++++++++++++++++++++------------------------- 2 files changed, 205 insertions(+), 205 deletions(-) diff --git a/exercise.ipynb b/exercise.ipynb index 5fbc67f..557e3a5 100644 --- a/exercise.ipynb +++ b/exercise.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "f929d6ee", + "id": "d5acca24", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "2def42e4", + "id": "cb662602", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "e386c146", + "id": "7af25d33", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ca1ceaeb", + "id": "f2d6db31", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "073875c1", + "id": "3ef3bfa9", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "634ec90e", + "id": "02be6af2", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "eca78719", + "id": "6a76f50d", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "104e3243", + "id": "fae7773e", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7b9b9220", + "id": "01199870", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -155,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "bc54ec3d", + "id": "137cd457", "metadata": { "lines_to_next_cell": 0 }, @@ -166,7 +166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5d8324c9", + "id": "3c09237d", "metadata": {}, "outputs": [], "source": [ @@ -193,7 +193,7 @@ }, { "cell_type": "markdown", - "id": "1970f094", + "id": "92af9975", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -203,7 +203,7 @@ }, { "cell_type": "markdown", - "id": "6dae9d0f", + "id": "c6d16b90", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -216,7 +216,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cb8da288", + "id": "29c8613f", "metadata": { "tags": [] }, @@ -234,7 +234,7 @@ }, { "cell_type": "markdown", - "id": "c6ea5e99", + "id": "ba7abfbe", "metadata": { "tags": [] }, @@ -250,7 +250,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79098ac5", + "id": "a842f6b9", "metadata": { "tags": [ "task" @@ -271,7 +271,7 @@ { "cell_type": "code", "execution_count": null, - "id": "db3ecbf5", + "id": "d4148689", "metadata": { "tags": [] }, @@ -284,7 +284,7 @@ }, { "cell_type": "markdown", - "id": "a5fc7736", + "id": "276c658d", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -296,7 +296,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7c214a9a", + "id": "30d569fb", "metadata": { "tags": [] }, @@ -324,7 +324,7 @@ { "cell_type": "code", "execution_count": null, - "id": "256ae2c6", + "id": "570a7ccd", "metadata": { "tags": [] }, @@ -337,7 +337,7 @@ }, { "cell_type": "markdown", - "id": "e477b37b", + "id": "4f13b2f4", "metadata": { "lines_to_next_cell": 2 }, @@ -351,7 +351,7 @@ }, { "cell_type": "markdown", - "id": "5660f1bc", + "id": "8c8790d3", "metadata": { "lines_to_next_cell": 0 }, @@ -364,7 +364,7 @@ { "cell_type": "code", "execution_count": null, - "id": "438c2f08", + "id": "ab2e8a0d", "metadata": {}, "outputs": [], "source": [ @@ -389,7 +389,7 @@ }, { "cell_type": "markdown", - "id": "9049d759", + "id": "b031ab39", "metadata": { "lines_to_next_cell": 0 }, @@ -403,7 +403,7 @@ }, { "cell_type": "markdown", - "id": "57f455ec", + "id": "e0a15147", "metadata": {}, "source": [ "\n", @@ -429,7 +429,7 @@ }, { "cell_type": "markdown", - "id": "3eb81cd4", + "id": "10e59bd9", "metadata": {}, "source": [ "

      Task 2.3: Use random noise as a baseline

      \n", @@ -441,7 +441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5c64ea11", + "id": "939622d5", "metadata": { "tags": [ "task" @@ -462,7 +462,7 @@ }, { "cell_type": "markdown", - "id": "5811ca2f", + "id": "82e2a1b3", "metadata": { "tags": [] }, @@ -476,7 +476,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d0a216f9", + "id": "8fae20a2", "metadata": { "tags": [ "task" @@ -499,7 +499,7 @@ }, { "cell_type": "markdown", - "id": "dee36774", + "id": "446d0aab", "metadata": { "tags": [] }, @@ -515,7 +515,7 @@ }, { "cell_type": "markdown", - "id": "23f7f7d8", + "id": "c0107144", "metadata": {}, "source": [ "

      BONUS Task: Using different attributions.

      \n", @@ -529,7 +529,7 @@ }, { "cell_type": "markdown", - "id": "dfdad8b4", + "id": "9a69110b", "metadata": {}, "source": [ "

      Checkpoint 2

      \n", @@ -549,7 +549,7 @@ }, { "cell_type": "markdown", - "id": "d47d79a3", + "id": "d6cd057a", "metadata": { "lines_to_next_cell": 0 }, @@ -577,7 +577,7 @@ }, { "cell_type": "markdown", - "id": "f4e5602e", + "id": "003cd7d6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -600,7 +600,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0ee0fb95", + "id": "c99ee37d", "metadata": {}, "outputs": [], "source": [ @@ -632,7 +632,7 @@ }, { "cell_type": "markdown", - "id": "4794a5fc", + "id": "9a1ba5d0", "metadata": { "lines_to_next_cell": 0 }, @@ -647,7 +647,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a111d5e8", + "id": "0e255fc6", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -668,7 +668,7 @@ }, { "cell_type": "markdown", - "id": "f1716b47", + "id": "d0d5441a", "metadata": { "tags": [] }, @@ -683,7 +683,7 @@ }, { "cell_type": "markdown", - "id": "5d54535a", + "id": "1a5da3b0", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -700,7 +700,7 @@ { "cell_type": "code", "execution_count": null, - "id": "39f5db85", + "id": "a17d88eb", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -714,7 +714,7 @@ }, { "cell_type": "markdown", - "id": "1db62898", + "id": "6152ae96", "metadata": { "lines_to_next_cell": 0 }, @@ -725,7 +725,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c22a89e4", + "id": "47ccc999", "metadata": {}, "outputs": [], "source": [ @@ -735,7 +735,7 @@ }, { "cell_type": "markdown", - "id": "2fd71dd0", + "id": "235ecf34", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -753,7 +753,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b7c89896", + "id": "ecd33083", "metadata": { "lines_to_next_cell": 0 }, @@ -765,7 +765,7 @@ }, { "cell_type": "markdown", - "id": "38280c63", + "id": "46499a30", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -784,7 +784,7 @@ { "cell_type": "code", "execution_count": null, - "id": "52034753", + "id": "a75d1d8e", "metadata": {}, "outputs": [], "source": [ @@ -793,7 +793,7 @@ }, { "cell_type": "markdown", - "id": "501b7a7e", + "id": "6f1e7957", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -809,7 +809,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c26dc615", + "id": "48c11783", "metadata": {}, "outputs": [], "source": [ @@ -818,7 +818,7 @@ }, { "cell_type": "markdown", - "id": "85a5ffc4", + "id": "7a403168", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -830,7 +830,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8822bdf7", + "id": "197f0b55", "metadata": {}, "outputs": [], "source": [ @@ -843,7 +843,7 @@ }, { "cell_type": "markdown", - "id": "49dff835", + "id": "14642c7c", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -857,7 +857,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bef2d8e0", + "id": "8567a4c7", "metadata": {}, "outputs": [], "source": [ @@ -869,7 +869,7 @@ }, { "cell_type": "markdown", - "id": "6cd15802", + "id": "d617a57b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -889,7 +889,7 @@ { "cell_type": "code", "execution_count": null, - "id": "364c8d90", + "id": "152d1fc2", "metadata": {}, "outputs": [], "source": [ @@ -913,7 +913,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b0c26036", + "id": "1beae4d6", "metadata": {}, "outputs": [], "source": [ @@ -923,7 +923,7 @@ }, { "cell_type": "markdown", - "id": "80c55b83", + "id": "6fba9909", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -945,7 +945,7 @@ }, { "cell_type": "markdown", - "id": "d6935683", + "id": "07dd61b7", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -957,7 +957,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5a11de5e", + "id": "df4f3916", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1068,7 +1068,7 @@ }, { "cell_type": "markdown", - "id": "580be770", + "id": "0574b61f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1080,7 +1080,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ff582fa2", + "id": "94723f87", "metadata": {}, "outputs": [], "source": [ @@ -1096,7 +1096,7 @@ }, { "cell_type": "markdown", - "id": "6f66943c", + "id": "a734af98", "metadata": { "tags": [] }, @@ -1111,7 +1111,7 @@ }, { "cell_type": "markdown", - "id": "b75e3ea6", + "id": "34e2b5cd", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1123,7 +1123,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e4661010", + "id": "22f64934", "metadata": {}, "outputs": [], "source": [ @@ -1145,7 +1145,7 @@ }, { "cell_type": "markdown", - "id": "bd665cb6", + "id": "d8dfbe12", "metadata": { "tags": [] }, @@ -1161,7 +1161,7 @@ }, { "cell_type": "markdown", - "id": "657df1b2", + "id": "32991250", "metadata": { "tags": [] }, @@ -1171,7 +1171,7 @@ }, { "cell_type": "markdown", - "id": "62ee3d1f", + "id": "216383e1", "metadata": { "tags": [] }, @@ -1188,7 +1188,7 @@ { "cell_type": "code", "execution_count": null, - "id": "340f5689", + "id": "54fd4016", "metadata": { "title": "Loading the test dataset" }, @@ -1208,7 +1208,7 @@ }, { "cell_type": "markdown", - "id": "fac75d37", + "id": "1dbcc267", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1220,7 +1220,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9a901d76", + "id": "04b20178", "metadata": {}, "outputs": [], "source": [ @@ -1233,7 +1233,7 @@ }, { "cell_type": "markdown", - "id": "194b5756", + "id": "ea19f383", "metadata": { "lines_to_next_cell": 0 }, @@ -1243,7 +1243,7 @@ }, { "cell_type": "markdown", - "id": "c6b2d5d5", + "id": "e10ea481", "metadata": { "lines_to_next_cell": 0 }, @@ -1261,7 +1261,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bbe08ad9", + "id": "24f75a7b", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -1297,7 +1297,7 @@ }, { "cell_type": "markdown", - "id": "4c1cc8e2", + "id": "44a96f81", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1309,7 +1309,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d2f3e241", + "id": "83523c2f", "metadata": {}, "outputs": [], "source": [ @@ -1322,7 +1322,7 @@ }, { "cell_type": "markdown", - "id": "f085ddbd", + "id": "bee1aa55", "metadata": { "tags": [] }, @@ -1337,7 +1337,7 @@ }, { "cell_type": "markdown", - "id": "3713cc97", + "id": "dbac0e2e", "metadata": { "tags": [] }, @@ -1348,7 +1348,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d4d09ab3", + "id": "7d263995", "metadata": {}, "outputs": [], "source": [ @@ -1362,7 +1362,7 @@ }, { "cell_type": "markdown", - "id": "d8071f41", + "id": "60081f81", "metadata": { "tags": [] }, @@ -1377,7 +1377,7 @@ }, { "cell_type": "markdown", - "id": "e8a1bdd8", + "id": "a037d368", "metadata": { "lines_to_next_cell": 0 }, @@ -1392,7 +1392,7 @@ { "cell_type": "code", "execution_count": null, - "id": "27a5faa2", + "id": "308d7f21", "metadata": {}, "outputs": [], "source": [ @@ -1413,7 +1413,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe913d7c", + "id": "12c79b2d", "metadata": { "title": "Another visualization function" }, @@ -1442,7 +1442,7 @@ { "cell_type": "code", "execution_count": null, - "id": "088a3080", + "id": "aa24deb7", "metadata": { "lines_to_next_cell": 0 }, @@ -1458,7 +1458,7 @@ }, { "cell_type": "markdown", - "id": "4c183323", + "id": "1726fcb6", "metadata": { "lines_to_next_cell": 0 }, @@ -1474,7 +1474,7 @@ }, { "cell_type": "markdown", - "id": "6e010c04", + "id": "93aff28e", "metadata": { "lines_to_next_cell": 0 }, @@ -1489,7 +1489,7 @@ }, { "cell_type": "markdown", - "id": "f83ddb66", + "id": "0389b58c", "metadata": { "lines_to_next_cell": 0 }, @@ -1503,7 +1503,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a5b7362", + "id": "61b9f9a0", "metadata": { "lines_to_next_cell": 0 }, @@ -1523,7 +1523,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9cdd2182", + "id": "7e70b512", "metadata": { "lines_to_next_cell": 0 }, @@ -1559,7 +1559,7 @@ }, { "cell_type": "markdown", - "id": "e3b495ef", + "id": "e9126f01", "metadata": { "lines_to_next_cell": 0 }, @@ -1573,7 +1573,7 @@ }, { "cell_type": "markdown", - "id": "d12b0491", + "id": "d50ea18c", "metadata": { "lines_to_next_cell": 0 }, @@ -1589,7 +1589,7 @@ }, { "cell_type": "markdown", - "id": "4d04321c", + "id": "a93a5437", "metadata": { "lines_to_next_cell": 0 }, @@ -1605,7 +1605,7 @@ }, { "cell_type": "markdown", - "id": "74cf3a6d", + "id": "b8f44b99", "metadata": { "lines_to_next_cell": 0 }, @@ -1628,7 +1628,7 @@ }, { "cell_type": "markdown", - "id": "3a2e2d67", + "id": "aafe8844", "metadata": {}, "source": [ "

      Task 5.1: Explore the style space

      \n", @@ -1640,7 +1640,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0ce55391", + "id": "0128d7bd", "metadata": {}, "outputs": [], "source": [ @@ -1675,7 +1675,7 @@ }, { "cell_type": "markdown", - "id": "93beff08", + "id": "609028bd", "metadata": { "lines_to_next_cell": 0 }, @@ -1691,7 +1691,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0f1e9fae", + "id": "8b7185c5", "metadata": { "lines_to_next_cell": 0 }, @@ -1718,7 +1718,7 @@ }, { "cell_type": "markdown", - "id": "f2a20f61", + "id": "772e4231", "metadata": { "lines_to_next_cell": 0 }, @@ -1732,7 +1732,7 @@ }, { "cell_type": "markdown", - "id": "ec1bb346", + "id": "7d159c2b", "metadata": { "lines_to_next_cell": 0 }, @@ -1749,7 +1749,7 @@ { "cell_type": "code", "execution_count": null, - "id": "92da5e43", + "id": "13eb454f", "metadata": {}, "outputs": [], "source": [ @@ -1772,7 +1772,7 @@ { "cell_type": "code", "execution_count": null, - "id": "017f3089", + "id": "3ddc6a07", "metadata": { "lines_to_next_cell": 0 }, @@ -1781,7 +1781,7 @@ }, { "cell_type": "markdown", - "id": "e4153d66", + "id": "303338b1", "metadata": {}, "source": [ "

      Questions

      \n", @@ -1793,7 +1793,7 @@ }, { "cell_type": "markdown", - "id": "f9218632", + "id": "e18da1e0", "metadata": {}, "source": [ "

      Checkpoint 5

      \n", @@ -1811,7 +1811,7 @@ }, { "cell_type": "markdown", - "id": "7413f5e2", + "id": "b74f495b", "metadata": { "lines_to_next_cell": 0 }, diff --git a/solution.ipynb b/solution.ipynb index ddf44e5..dcdff0d 100644 --- a/solution.ipynb +++ b/solution.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "f929d6ee", + "id": "d5acca24", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -29,7 +29,7 @@ }, { "cell_type": "markdown", - "id": "2def42e4", + "id": "cb662602", "metadata": { "lines_to_next_cell": 0 }, @@ -41,7 +41,7 @@ }, { "cell_type": "markdown", - "id": "e386c146", + "id": "7af25d33", "metadata": {}, "source": [ "\n", @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ca1ceaeb", + "id": "f2d6db31", "metadata": { "lines_to_next_cell": 0 }, @@ -68,7 +68,7 @@ }, { "cell_type": "markdown", - "id": "073875c1", + "id": "3ef3bfa9", "metadata": { "lines_to_next_cell": 0 }, @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "634ec90e", + "id": "02be6af2", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "eca78719", + "id": "6a76f50d", "metadata": { "lines_to_next_cell": 0 }, @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "104e3243", + "id": "fae7773e", "metadata": { "lines_to_next_cell": 0 }, @@ -130,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8df1ad50", + "id": "d0b3ad78", "metadata": { "tags": [ "solution" @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "bc54ec3d", + "id": "137cd457", "metadata": { "lines_to_next_cell": 0 }, @@ -165,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5d8324c9", + "id": "3c09237d", "metadata": {}, "outputs": [], "source": [ @@ -192,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "1970f094", + "id": "92af9975", "metadata": {}, "source": [ "# Part 2: Using Integrated Gradients to find what the classifier knows\n", @@ -202,7 +202,7 @@ }, { "cell_type": "markdown", - "id": "6dae9d0f", + "id": "c6d16b90", "metadata": {}, "source": [ "## Attributions through integrated gradients\n", @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cb8da288", + "id": "29c8613f", "metadata": { "tags": [] }, @@ -233,7 +233,7 @@ }, { "cell_type": "markdown", - "id": "c6ea5e99", + "id": "ba7abfbe", "metadata": { "tags": [] }, @@ -249,7 +249,7 @@ { "cell_type": "code", "execution_count": null, - "id": "029446ea", + "id": "c4e89e02", "metadata": { "tags": [ "solution" @@ -273,7 +273,7 @@ { "cell_type": "code", "execution_count": null, - "id": "db3ecbf5", + "id": "d4148689", "metadata": { "tags": [] }, @@ -286,7 +286,7 @@ }, { "cell_type": "markdown", - "id": "a5fc7736", + "id": "276c658d", "metadata": { "lines_to_next_cell": 2, "tags": [] @@ -298,7 +298,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7c214a9a", + "id": "30d569fb", "metadata": { "tags": [] }, @@ -326,7 +326,7 @@ { "cell_type": "code", "execution_count": null, - "id": "256ae2c6", + "id": "570a7ccd", "metadata": { "tags": [] }, @@ -339,7 +339,7 @@ }, { "cell_type": "markdown", - "id": "e477b37b", + "id": "4f13b2f4", "metadata": { "lines_to_next_cell": 2 }, @@ -353,7 +353,7 @@ }, { "cell_type": "markdown", - "id": "5660f1bc", + "id": "8c8790d3", "metadata": { "lines_to_next_cell": 0 }, @@ -366,7 +366,7 @@ { "cell_type": "code", "execution_count": null, - "id": "438c2f08", + "id": "ab2e8a0d", "metadata": {}, "outputs": [], "source": [ @@ -391,7 +391,7 @@ }, { "cell_type": "markdown", - "id": "9049d759", + "id": "b031ab39", "metadata": { "lines_to_next_cell": 0 }, @@ -405,7 +405,7 @@ }, { "cell_type": "markdown", - "id": "57f455ec", + "id": "e0a15147", "metadata": {}, "source": [ "\n", @@ -431,7 +431,7 @@ }, { "cell_type": "markdown", - "id": "3eb81cd4", + "id": "10e59bd9", "metadata": {}, "source": [ "

      Task 2.3: Use random noise as a baseline

      \n", @@ -443,7 +443,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1f1843ae", + "id": "57012d02", "metadata": { "tags": [ "solution" @@ -469,7 +469,7 @@ }, { "cell_type": "markdown", - "id": "5811ca2f", + "id": "82e2a1b3", "metadata": { "tags": [] }, @@ -483,7 +483,7 @@ { "cell_type": "code", "execution_count": null, - "id": "45789751", + "id": "97f80e4a", "metadata": { "tags": [ "solution" @@ -511,7 +511,7 @@ }, { "cell_type": "markdown", - "id": "dee36774", + "id": "446d0aab", "metadata": { "tags": [] }, @@ -527,7 +527,7 @@ }, { "cell_type": "markdown", - "id": "23f7f7d8", + "id": "c0107144", "metadata": {}, "source": [ "

      BONUS Task: Using different attributions.

      \n", @@ -541,7 +541,7 @@ }, { "cell_type": "markdown", - "id": "dfdad8b4", + "id": "9a69110b", "metadata": {}, "source": [ "

      Checkpoint 2

      \n", @@ -561,7 +561,7 @@ }, { "cell_type": "markdown", - "id": "d47d79a3", + "id": "d6cd057a", "metadata": { "lines_to_next_cell": 0 }, @@ -589,7 +589,7 @@ }, { "cell_type": "markdown", - "id": "f4e5602e", + "id": "003cd7d6", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -612,7 +612,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0ee0fb95", + "id": "c99ee37d", "metadata": {}, "outputs": [], "source": [ @@ -644,7 +644,7 @@ }, { "cell_type": "markdown", - "id": "4794a5fc", + "id": "9a1ba5d0", "metadata": { "lines_to_next_cell": 0 }, @@ -659,7 +659,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ab9dd94e", + "id": "23a3a52c", "metadata": { "tags": [ "solution" @@ -676,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "f1716b47", + "id": "d0d5441a", "metadata": { "tags": [] }, @@ -691,7 +691,7 @@ }, { "cell_type": "markdown", - "id": "5d54535a", + "id": "1a5da3b0", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -708,7 +708,7 @@ { "cell_type": "code", "execution_count": null, - "id": "069675ed", + "id": "7488ac02", "metadata": { "lines_to_next_cell": 0, "tags": [ @@ -722,7 +722,7 @@ }, { "cell_type": "markdown", - "id": "1db62898", + "id": "6152ae96", "metadata": { "lines_to_next_cell": 0 }, @@ -733,7 +733,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c22a89e4", + "id": "47ccc999", "metadata": {}, "outputs": [], "source": [ @@ -743,7 +743,7 @@ }, { "cell_type": "markdown", - "id": "2fd71dd0", + "id": "235ecf34", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -761,7 +761,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b7c89896", + "id": "ecd33083", "metadata": { "lines_to_next_cell": 0 }, @@ -773,7 +773,7 @@ }, { "cell_type": "markdown", - "id": "38280c63", + "id": "46499a30", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -792,7 +792,7 @@ { "cell_type": "code", "execution_count": null, - "id": "52034753", + "id": "a75d1d8e", "metadata": {}, "outputs": [], "source": [ @@ -801,7 +801,7 @@ }, { "cell_type": "markdown", - "id": "501b7a7e", + "id": "6f1e7957", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -817,7 +817,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c26dc615", + "id": "48c11783", "metadata": {}, "outputs": [], "source": [ @@ -826,7 +826,7 @@ }, { "cell_type": "markdown", - "id": "85a5ffc4", + "id": "7a403168", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -838,7 +838,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8822bdf7", + "id": "197f0b55", "metadata": {}, "outputs": [], "source": [ @@ -851,7 +851,7 @@ }, { "cell_type": "markdown", - "id": "49dff835", + "id": "14642c7c", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -865,7 +865,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bef2d8e0", + "id": "8567a4c7", "metadata": {}, "outputs": [], "source": [ @@ -877,7 +877,7 @@ }, { "cell_type": "markdown", - "id": "6cd15802", + "id": "d617a57b", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -897,7 +897,7 @@ { "cell_type": "code", "execution_count": null, - "id": "364c8d90", + "id": "152d1fc2", "metadata": {}, "outputs": [], "source": [ @@ -921,7 +921,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b0c26036", + "id": "1beae4d6", "metadata": {}, "outputs": [], "source": [ @@ -931,7 +931,7 @@ }, { "cell_type": "markdown", - "id": "80c55b83", + "id": "6fba9909", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -953,7 +953,7 @@ }, { "cell_type": "markdown", - "id": "d6935683", + "id": "07dd61b7", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -965,7 +965,7 @@ { "cell_type": "code", "execution_count": null, - "id": "001e8548", + "id": "5e54a534", "metadata": { "lines_to_next_cell": 2, "tags": [ @@ -1035,7 +1035,7 @@ }, { "cell_type": "markdown", - "id": "580be770", + "id": "0574b61f", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1047,7 +1047,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ff582fa2", + "id": "94723f87", "metadata": {}, "outputs": [], "source": [ @@ -1063,7 +1063,7 @@ }, { "cell_type": "markdown", - "id": "6f66943c", + "id": "a734af98", "metadata": { "tags": [] }, @@ -1078,7 +1078,7 @@ }, { "cell_type": "markdown", - "id": "b75e3ea6", + "id": "34e2b5cd", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1090,7 +1090,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e4661010", + "id": "22f64934", "metadata": {}, "outputs": [], "source": [ @@ -1112,7 +1112,7 @@ }, { "cell_type": "markdown", - "id": "bd665cb6", + "id": "d8dfbe12", "metadata": { "tags": [] }, @@ -1128,7 +1128,7 @@ }, { "cell_type": "markdown", - "id": "657df1b2", + "id": "32991250", "metadata": { "tags": [] }, @@ -1138,7 +1138,7 @@ }, { "cell_type": "markdown", - "id": "62ee3d1f", + "id": "216383e1", "metadata": { "tags": [] }, @@ -1155,7 +1155,7 @@ { "cell_type": "code", "execution_count": null, - "id": "340f5689", + "id": "54fd4016", "metadata": { "title": "Loading the test dataset" }, @@ -1175,7 +1175,7 @@ }, { "cell_type": "markdown", - "id": "fac75d37", + "id": "1dbcc267", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1187,7 +1187,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9a901d76", + "id": "04b20178", "metadata": {}, "outputs": [], "source": [ @@ -1200,7 +1200,7 @@ }, { "cell_type": "markdown", - "id": "194b5756", + "id": "ea19f383", "metadata": { "lines_to_next_cell": 0 }, @@ -1210,7 +1210,7 @@ }, { "cell_type": "markdown", - "id": "c6b2d5d5", + "id": "e10ea481", "metadata": { "lines_to_next_cell": 0 }, @@ -1228,7 +1228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37010964", + "id": "780e6fdd", "metadata": { "tags": [ "solution" @@ -1265,7 +1265,7 @@ }, { "cell_type": "markdown", - "id": "4c1cc8e2", + "id": "44a96f81", "metadata": { "lines_to_next_cell": 0, "tags": [] @@ -1277,7 +1277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d2f3e241", + "id": "83523c2f", "metadata": {}, "outputs": [], "source": [ @@ -1290,7 +1290,7 @@ }, { "cell_type": "markdown", - "id": "f085ddbd", + "id": "bee1aa55", "metadata": { "tags": [] }, @@ -1305,7 +1305,7 @@ }, { "cell_type": "markdown", - "id": "3713cc97", + "id": "dbac0e2e", "metadata": { "tags": [] }, @@ -1316,7 +1316,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d4d09ab3", + "id": "7d263995", "metadata": {}, "outputs": [], "source": [ @@ -1330,7 +1330,7 @@ }, { "cell_type": "markdown", - "id": "d8071f41", + "id": "60081f81", "metadata": { "tags": [] }, @@ -1345,7 +1345,7 @@ }, { "cell_type": "markdown", - "id": "e8a1bdd8", + "id": "a037d368", "metadata": { "lines_to_next_cell": 0 }, @@ -1360,7 +1360,7 @@ { "cell_type": "code", "execution_count": null, - "id": "27a5faa2", + "id": "308d7f21", "metadata": {}, "outputs": [], "source": [ @@ -1381,7 +1381,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe913d7c", + "id": "12c79b2d", "metadata": { "title": "Another visualization function" }, @@ -1410,7 +1410,7 @@ { "cell_type": "code", "execution_count": null, - "id": "088a3080", + "id": "aa24deb7", "metadata": { "lines_to_next_cell": 0 }, @@ -1426,7 +1426,7 @@ }, { "cell_type": "markdown", - "id": "4c183323", + "id": "1726fcb6", "metadata": { "lines_to_next_cell": 0 }, @@ -1442,7 +1442,7 @@ }, { "cell_type": "markdown", - "id": "6e010c04", + "id": "93aff28e", "metadata": { "lines_to_next_cell": 0 }, @@ -1457,7 +1457,7 @@ }, { "cell_type": "markdown", - "id": "f83ddb66", + "id": "0389b58c", "metadata": { "lines_to_next_cell": 0 }, @@ -1471,7 +1471,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a5b7362", + "id": "61b9f9a0", "metadata": { "lines_to_next_cell": 0 }, @@ -1491,7 +1491,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9cdd2182", + "id": "7e70b512", "metadata": { "lines_to_next_cell": 0 }, @@ -1527,7 +1527,7 @@ }, { "cell_type": "markdown", - "id": "e3b495ef", + "id": "e9126f01", "metadata": { "lines_to_next_cell": 0 }, @@ -1541,7 +1541,7 @@ }, { "cell_type": "markdown", - "id": "d12b0491", + "id": "d50ea18c", "metadata": { "lines_to_next_cell": 0 }, @@ -1557,7 +1557,7 @@ }, { "cell_type": "markdown", - "id": "4d04321c", + "id": "a93a5437", "metadata": { "lines_to_next_cell": 0 }, @@ -1573,7 +1573,7 @@ }, { "cell_type": "markdown", - "id": "74cf3a6d", + "id": "b8f44b99", "metadata": { "lines_to_next_cell": 0 }, @@ -1596,7 +1596,7 @@ }, { "cell_type": "markdown", - "id": "3a2e2d67", + "id": "aafe8844", "metadata": {}, "source": [ "

      Task 5.1: Explore the style space

      \n", @@ -1608,7 +1608,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0ce55391", + "id": "0128d7bd", "metadata": {}, "outputs": [], "source": [ @@ -1643,7 +1643,7 @@ }, { "cell_type": "markdown", - "id": "93beff08", + "id": "609028bd", "metadata": { "lines_to_next_cell": 0 }, @@ -1659,7 +1659,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0f1e9fae", + "id": "8b7185c5", "metadata": { "lines_to_next_cell": 0 }, @@ -1686,7 +1686,7 @@ }, { "cell_type": "markdown", - "id": "f2a20f61", + "id": "772e4231", "metadata": { "lines_to_next_cell": 0 }, @@ -1700,7 +1700,7 @@ }, { "cell_type": "markdown", - "id": "ec1bb346", + "id": "7d159c2b", "metadata": { "lines_to_next_cell": 0 }, @@ -1717,7 +1717,7 @@ { "cell_type": "code", "execution_count": null, - "id": "92da5e43", + "id": "13eb454f", "metadata": {}, "outputs": [], "source": [ @@ -1740,7 +1740,7 @@ { "cell_type": "code", "execution_count": null, - "id": "017f3089", + "id": "3ddc6a07", "metadata": { "lines_to_next_cell": 0 }, @@ -1749,7 +1749,7 @@ }, { "cell_type": "markdown", - "id": "e4153d66", + "id": "303338b1", "metadata": {}, "source": [ "

      Questions

      \n", @@ -1761,7 +1761,7 @@ }, { "cell_type": "markdown", - "id": "f9218632", + "id": "e18da1e0", "metadata": {}, "source": [ "

      Checkpoint 5

      \n", @@ -1779,7 +1779,7 @@ }, { "cell_type": "markdown", - "id": "7413f5e2", + "id": "b74f495b", "metadata": { "lines_to_next_cell": 0 }, @@ -1801,7 +1801,7 @@ { "cell_type": "code", "execution_count": null, - "id": "97bc18ec", + "id": "52f77a00", "metadata": { "tags": [ "solution"