\n",
@@ -529,7 +529,7 @@
},
{
"cell_type": "markdown",
- "id": "758de8e5",
+ "id": "9a69110b",
"metadata": {},
"source": [
"
Checkpoint 2
\n",
@@ -549,7 +549,7 @@
},
{
"cell_type": "markdown",
- "id": "9dd5ce8d",
+ "id": "d6cd057a",
"metadata": {
"lines_to_next_cell": 0
},
@@ -577,7 +577,7 @@
},
{
"cell_type": "markdown",
- "id": "4497f894",
+ "id": "003cd7d6",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -600,7 +600,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4eba79a8",
+ "id": "c99ee37d",
"metadata": {},
"outputs": [],
"source": [
@@ -632,7 +632,7 @@
},
{
"cell_type": "markdown",
- "id": "4e369a2c",
+ "id": "9a1ba5d0",
"metadata": {
"lines_to_next_cell": 0
},
@@ -647,7 +647,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7dfcc3a5",
+ "id": "0e255fc6",
"metadata": {
"lines_to_next_cell": 0,
"tags": [
@@ -668,7 +668,7 @@
},
{
"cell_type": "markdown",
- "id": "2f351532",
+ "id": "d0d5441a",
"metadata": {
"tags": []
},
@@ -683,7 +683,7 @@
},
{
"cell_type": "markdown",
- "id": "fef243a6",
+ "id": "1a5da3b0",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -700,7 +700,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "84848da5",
+ "id": "a17d88eb",
"metadata": {
"lines_to_next_cell": 0,
"tags": [
@@ -714,7 +714,7 @@
},
{
"cell_type": "markdown",
- "id": "331a5aa9",
+ "id": "6152ae96",
"metadata": {
"lines_to_next_cell": 0
},
@@ -725,7 +725,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5da1667d",
+ "id": "47ccc999",
"metadata": {},
"outputs": [],
"source": [
@@ -735,7 +735,7 @@
},
{
"cell_type": "markdown",
- "id": "52c90cba",
+ "id": "235ecf34",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -753,7 +753,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "bbbf486c",
+ "id": "ecd33083",
"metadata": {
"lines_to_next_cell": 0
},
@@ -765,7 +765,7 @@
},
{
"cell_type": "markdown",
- "id": "900d4b51",
+ "id": "46499a30",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -784,7 +784,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "ab86e2bc",
+ "id": "a75d1d8e",
"metadata": {},
"outputs": [],
"source": [
@@ -793,7 +793,7 @@
},
{
"cell_type": "markdown",
- "id": "bb7046e9",
+ "id": "6f1e7957",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -809,7 +809,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "049d4a98",
+ "id": "48c11783",
"metadata": {},
"outputs": [],
"source": [
@@ -818,7 +818,7 @@
},
{
"cell_type": "markdown",
- "id": "8c3e7170",
+ "id": "7a403168",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -830,7 +830,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "209395dd",
+ "id": "197f0b55",
"metadata": {},
"outputs": [],
"source": [
@@ -843,7 +843,7 @@
},
{
"cell_type": "markdown",
- "id": "2513fff6",
+ "id": "14642c7c",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -857,7 +857,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e6413b39",
+ "id": "8567a4c7",
"metadata": {},
"outputs": [],
"source": [
@@ -869,7 +869,7 @@
},
{
"cell_type": "markdown",
- "id": "1022dd64",
+ "id": "d617a57b",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -889,7 +889,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3765f525",
+ "id": "152d1fc2",
"metadata": {},
"outputs": [],
"source": [
@@ -913,7 +913,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e6dc12c2",
+ "id": "1beae4d6",
"metadata": {},
"outputs": [],
"source": [
@@ -923,7 +923,7 @@
},
{
"cell_type": "markdown",
- "id": "f572e6ab",
+ "id": "6fba9909",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -945,7 +945,7 @@
},
{
"cell_type": "markdown",
- "id": "0725d30d",
+ "id": "07dd61b7",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -957,7 +957,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "6feaf7e3",
+ "id": "df4f3916",
"metadata": {
"lines_to_next_cell": 0,
"tags": [
@@ -1068,7 +1068,7 @@
},
{
"cell_type": "markdown",
- "id": "cf7e2d33",
+ "id": "0574b61f",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1080,7 +1080,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f8f84461",
+ "id": "94723f87",
"metadata": {},
"outputs": [],
"source": [
@@ -1096,7 +1096,7 @@
},
{
"cell_type": "markdown",
- "id": "425b6eb8",
+ "id": "a734af98",
"metadata": {
"tags": []
},
@@ -1111,7 +1111,7 @@
},
{
"cell_type": "markdown",
- "id": "0f81593a",
+ "id": "34e2b5cd",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1123,7 +1123,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8fb9650f",
+ "id": "22f64934",
"metadata": {},
"outputs": [],
"source": [
@@ -1143,19 +1143,9 @@
"plt.show()"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "cf9cfe92",
- "metadata": {
- "lines_to_next_cell": 0
- },
- "outputs": [],
- "source": []
- },
{
"cell_type": "markdown",
- "id": "c31b5230",
+ "id": "d8dfbe12",
"metadata": {
"tags": []
},
@@ -1171,7 +1161,7 @@
},
{
"cell_type": "markdown",
- "id": "afb5d3fe",
+ "id": "32991250",
"metadata": {
"tags": []
},
@@ -1181,7 +1171,7 @@
},
{
"cell_type": "markdown",
- "id": "96e6c29a",
+ "id": "216383e1",
"metadata": {
"tags": []
},
@@ -1198,7 +1188,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8fcd2b8a",
+ "id": "54fd4016",
"metadata": {
"title": "Loading the test dataset"
},
@@ -1218,7 +1208,7 @@
},
{
"cell_type": "markdown",
- "id": "e58a66c1",
+ "id": "1dbcc267",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1230,7 +1220,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "b6984bec",
+ "id": "04b20178",
"metadata": {},
"outputs": [],
"source": [
@@ -1243,7 +1233,7 @@
},
{
"cell_type": "markdown",
- "id": "09a033e4",
+ "id": "ea19f383",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1253,12 +1243,12 @@
},
{
"cell_type": "markdown",
- "id": "8d6dafc8",
+ "id": "e10ea481",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
-    "Task 4: Create counterfactuals\n",
+    "Task 4.1: Create counterfactuals\n",
"In the below, we will store the counterfactual images in the `counterfactuals` array.\n",
"\n",
"
\n",
@@ -1271,7 +1261,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "98d8e7f1",
+ "id": "24f75a7b",
"metadata": {
"lines_to_next_cell": 0,
"tags": [
@@ -1307,7 +1297,7 @@
},
{
"cell_type": "markdown",
- "id": "5ed9dc6b",
+ "id": "44a96f81",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1319,7 +1309,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "83220035",
+ "id": "83523c2f",
"metadata": {},
"outputs": [],
"source": [
@@ -1332,7 +1322,7 @@
},
{
"cell_type": "markdown",
- "id": "48b400d0",
+ "id": "bee1aa55",
"metadata": {
"tags": []
},
@@ -1347,7 +1337,7 @@
},
{
"cell_type": "markdown",
- "id": "28a66131",
+ "id": "dbac0e2e",
"metadata": {
"tags": []
},
@@ -1358,7 +1348,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "ba72cac7",
+ "id": "7d263995",
"metadata": {},
"outputs": [],
"source": [
@@ -1372,7 +1362,7 @@
},
{
"cell_type": "markdown",
- "id": "7919e4d2",
+ "id": "60081f81",
"metadata": {
"tags": []
},
@@ -1387,7 +1377,7 @@
},
{
"cell_type": "markdown",
- "id": "5727914c",
+ "id": "a037d368",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1402,15 +1392,16 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0db3b836",
+ "id": "308d7f21",
"metadata": {},
"outputs": [],
"source": [
+ "target_class = 0\n",
"batch_size = 4\n",
"batch = [random_test_mnist[i] for i in range(batch_size)]\n",
"x = torch.stack([b[0] for b in batch])\n",
"y = torch.tensor([b[1] for b in batch])\n",
- "x_fake = torch.tensor(counterfactuals[0, :batch_size])\n",
+ "x_fake = torch.tensor(counterfactuals[target_class, :batch_size])\n",
"x = x.to(device).float()\n",
"y = y.to(device)\n",
"x_fake = x_fake.to(device).float()\n",
@@ -1422,7 +1413,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "eda02588",
+ "id": "12c79b2d",
"metadata": {
"title": "Another visualization function"
},
@@ -1442,7 +1433,7 @@
" ax1.imshow(counterfactual_image)\n",
" ax1.set_title(\"Counterfactual\")\n",
" ax1.axis(\"off\")\n",
- " ax2.imshow(np.abs(attribution))\n",
+ " ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution)))\n",
" ax2.set_title(\"Attribution\")\n",
" ax2.axis(\"off\")\n",
" plt.show()"
@@ -1451,7 +1442,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8d6fc8bf",
+ "id": "aa24deb7",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1467,7 +1458,7 @@
},
{
"cell_type": "markdown",
- "id": "b723e27b",
+ "id": "1726fcb6",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1483,7 +1474,122 @@
},
{
"cell_type": "markdown",
- "id": "66ec380e",
+ "id": "93aff28e",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+    "In the lecture, we used the attribution as a mask to gradually move from the original image to the counterfactual image.\n",
+    "This allowed us to classify all of the intermediate images and learn how the class changed over the interpolation.\n",
+    "Here we have a much simpler task, so we have some advantages:\n",
+    "- The counterfactuals are perfect! They already change the bare minimum (trust me).\n",
+    "- The changes are not objects, but colors.\n",
+    "As such, we can use a simple linear interpolation between the images instead."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0389b58c",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+    "Task 4.2: Interpolation\n",
+    "Let's interpolate between the original image and the counterfactual image.\n",
+    "We will create 15 images in between the two and classify each of them.\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "61b9f9a0",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "num_interpolations = 15\n",
+ "alpha = np.linspace(0, 1, num_interpolations + 2)[1:-1]\n",
+ "interpolated_images = [\n",
+ " alpha[i] * x_fake + (1 - alpha[i]) * x for i in range(num_interpolations)\n",
+ "]\n",
+ "interpolated_images = torch.stack(interpolated_images)\n",
+ "interpolated_classifications = [\n",
+ " model(interpolated_images[idx].to(device)) for idx in range(num_interpolations)\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7e70b512",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "# Plot the results\n",
+ "idx = 0\n",
+ "fig, axs = plt.subplots(\n",
+ " batch_size, num_interpolations + 2, figsize=(30, 2 * batch_size)\n",
+ ")\n",
+ "for idx in range(batch_size):\n",
+ " # Plot the original image\n",
+ " axs[idx, 0].imshow(np.transpose(x[idx].cpu().squeeze().numpy(), (1, 2, 0)))\n",
+ " axs[idx, 0].axis(\"off\")\n",
+ " # Use the class as the title\n",
+ " axs[idx, 0].set_title(f\"Image: y={y[idx].item()}\")\n",
+ " # Plot the counterfactual image\n",
+ " axs[idx, -1].imshow(np.transpose(x_fake[idx].cpu().squeeze().numpy(), (1, 2, 0)))\n",
+ " axs[idx, -1].axis(\"off\")\n",
+ " # Use the target class as the title\n",
+ " axs[idx, -1].set_title(f\"CF: y={target_class}\")\n",
+ " for i, ax in enumerate(axs[idx][1:-1]):\n",
+ " ax.imshow(\n",
+ " np.transpose(interpolated_images[i][idx].cpu().squeeze().numpy(), (1, 2, 0))\n",
+ " )\n",
+ " ax.axis(\"off\")\n",
+ " classification = torch.softmax(interpolated_classifications[i][idx], dim=0)\n",
+ " # Plot the classification as the title in order source classification | target classification\n",
+ " ax.set_title(\n",
+ " f\"{classification[y[idx]].item():.2f} | {classification[target_class].item():.2f}\"\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e9126f01",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "Take some time to look at the plot we just made.\n",
+    "On the very left are the images we randomly chose - each image's class is shown in the title.\n",
+ "On the very right are the counterfactual images, all of them made with the same prototype as a style source - the target class is shown in the title.\n",
+ "In between are the interpolated images - their title shows their classification as \"source classification | target classification\".\n",
+ "This is a lot to take in, so take your time! Once you're ready, we can move on to the questions."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d50ea18c",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+    "Questions\n",
+    "\n",
+    "- Do the images change smoothly from one class to another?\n",
+    "- Can you see any patterns in the changes?\n",
+    "- What happens when the original image and the counterfactual image are of the same class?\n",
+    "- Based on this, would you trust this classifier on unseen images more or less than you did before?\n",
+    ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a93a5437",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1493,12 +1599,13 @@
"- Created a StarGAN that can change the class of an image\n",
"- Evaluated the StarGAN on unseen data\n",
"- Used the StarGAN to create counterfactual images\n",
- "- Used the counterfactual images to highlight the differences between classes\n"
+ "- Used the counterfactual images to highlight the differences between classes\n",
+ "- Interpolated between the images to see how the classifier behaves\n"
]
},
{
"cell_type": "markdown",
- "id": "7592e83d",
+ "id": "b8f44b99",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1521,7 +1628,7 @@
},
{
"cell_type": "markdown",
- "id": "a7400c5c",
+ "id": "aafe8844",
"metadata": {},
"source": [
"
Task 5.1: Explore the style space
\n",
@@ -1533,7 +1640,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f5493fd4",
+ "id": "0128d7bd",
"metadata": {},
"outputs": [],
"source": [
@@ -1568,7 +1675,7 @@
},
{
"cell_type": "markdown",
- "id": "d57f4253",
+ "id": "609028bd",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1584,7 +1691,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4160a32f",
+ "id": "8b7185c5",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1611,7 +1718,7 @@
},
{
"cell_type": "markdown",
- "id": "650c6766",
+ "id": "772e4231",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1625,7 +1732,7 @@
},
{
"cell_type": "markdown",
- "id": "f88d9484",
+ "id": "7d159c2b",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1642,7 +1749,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0b57d333",
+ "id": "13eb454f",
"metadata": {},
"outputs": [],
"source": [
@@ -1665,7 +1772,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4bb3433e",
+ "id": "3ddc6a07",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1674,7 +1781,7 @@
},
{
"cell_type": "markdown",
- "id": "516e0112",
+ "id": "303338b1",
"metadata": {},
"source": [
"
Questions
\n",
@@ -1686,7 +1793,7 @@
},
{
"cell_type": "markdown",
- "id": "63cef16c",
+ "id": "e18da1e0",
"metadata": {},
"source": [
"
Checkpoint 5
\n",
@@ -1701,6 +1808,23 @@
"If you have any questions, feel free to ask them in the chat!\n",
"And check the Solutions exercise for a definite answer to how these classes are defined!"
]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b74f495b",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "# Bonus!\n",
+ "If you have extra time, you can try to break the StarGAN!\n",
+ "There are a lot of little things that we did to make sure that it runs correctly - but what if we didn't?\n",
+ "Some things you might want to try:\n",
+ "- What happens if you don't use the EMA model?\n",
+ "- What happens if you change the learning rates?\n",
+ "- What happens if you add a Sigmoid activation to the output of the style encoder?\n",
+    "See what else you can think of, and see how finicky training a GAN can be!"
+ ]
}
],
"metadata": {
diff --git a/solution.ipynb b/solution.ipynb
index db548a4..dcdff0d 100644
--- a/solution.ipynb
+++ b/solution.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "id": "46b2ed85",
+ "id": "d5acca24",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -29,7 +29,7 @@
},
{
"cell_type": "markdown",
- "id": "3ca92ed6",
+ "id": "cb662602",
"metadata": {
"lines_to_next_cell": 0
},
@@ -41,7 +41,7 @@
},
{
"cell_type": "markdown",
- "id": "5452b543",
+ "id": "7af25d33",
"metadata": {},
"source": [
"\n",
@@ -54,7 +54,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3f6a6be8",
+ "id": "f2d6db31",
"metadata": {
"lines_to_next_cell": 0
},
@@ -68,7 +68,7 @@
},
{
"cell_type": "markdown",
- "id": "d521dde9",
+ "id": "3ef3bfa9",
"metadata": {
"lines_to_next_cell": 0
},
@@ -84,7 +84,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "6d1b1c06",
+ "id": "02be6af2",
"metadata": {},
"outputs": [],
"source": [
@@ -102,7 +102,7 @@
},
{
"cell_type": "markdown",
- "id": "90516193",
+ "id": "6a76f50d",
"metadata": {
"lines_to_next_cell": 0
},
@@ -113,7 +113,7 @@
},
{
"cell_type": "markdown",
- "id": "8ec094ee",
+ "id": "fae7773e",
"metadata": {
"lines_to_next_cell": 0
},
@@ -130,7 +130,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0801e1e5",
+ "id": "d0b3ad78",
"metadata": {
"tags": [
"solution"
@@ -154,7 +154,7 @@
},
{
"cell_type": "markdown",
- "id": "e8d2799d",
+ "id": "137cd457",
"metadata": {
"lines_to_next_cell": 0
},
@@ -165,7 +165,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e0a495ea",
+ "id": "3c09237d",
"metadata": {},
"outputs": [],
"source": [
@@ -192,7 +192,7 @@
},
{
"cell_type": "markdown",
- "id": "3e5df200",
+ "id": "92af9975",
"metadata": {},
"source": [
"# Part 2: Using Integrated Gradients to find what the classifier knows\n",
@@ -202,7 +202,7 @@
},
{
"cell_type": "markdown",
- "id": "f3b1fb74",
+ "id": "c6d16b90",
"metadata": {},
"source": [
"## Attributions through integrated gradients\n",
@@ -215,7 +215,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "df74345f",
+ "id": "29c8613f",
"metadata": {
"tags": []
},
@@ -233,7 +233,7 @@
},
{
"cell_type": "markdown",
- "id": "43f8c1b4",
+ "id": "ba7abfbe",
"metadata": {
"tags": []
},
@@ -249,7 +249,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "52eb1aa0",
+ "id": "c4e89e02",
"metadata": {
"tags": [
"solution"
@@ -273,7 +273,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "9907f91b",
+ "id": "d4148689",
"metadata": {
"tags": []
},
@@ -286,7 +286,7 @@
},
{
"cell_type": "markdown",
- "id": "61bab00d",
+ "id": "276c658d",
"metadata": {
"lines_to_next_cell": 2,
"tags": []
@@ -298,7 +298,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "93d8ac21",
+ "id": "30d569fb",
"metadata": {
"tags": []
},
@@ -326,7 +326,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "a2476939",
+ "id": "570a7ccd",
"metadata": {
"tags": []
},
@@ -339,7 +339,7 @@
},
{
"cell_type": "markdown",
- "id": "b7e03da8",
+ "id": "4f13b2f4",
"metadata": {
"lines_to_next_cell": 2
},
@@ -353,7 +353,7 @@
},
{
"cell_type": "markdown",
- "id": "627449ca",
+ "id": "8c8790d3",
"metadata": {
"lines_to_next_cell": 0
},
@@ -366,7 +366,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8c1cee24",
+ "id": "ab2e8a0d",
"metadata": {},
"outputs": [],
"source": [
@@ -378,7 +378,7 @@
" ax1.imshow(original_image)\n",
" ax1.set_title(\"Image\")\n",
" ax1.axis(\"off\")\n",
- " ax2.imshow(np.abs(attribution))\n",
+ " ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution)))\n",
" ax2.set_title(\"Attribution\")\n",
" ax2.axis(\"off\")\n",
" plt.show()\n",
@@ -391,7 +391,7 @@
},
{
"cell_type": "markdown",
- "id": "e887a55c",
+ "id": "b031ab39",
"metadata": {
"lines_to_next_cell": 0
},
@@ -405,7 +405,7 @@
},
{
"cell_type": "markdown",
- "id": "3f5d9121",
+ "id": "e0a15147",
"metadata": {},
"source": [
"\n",
@@ -431,7 +431,7 @@
},
{
"cell_type": "markdown",
- "id": "b0ea4cb1",
+ "id": "10e59bd9",
"metadata": {},
"source": [
"
Task 2.3: Use random noise as a baseline
\n",
@@ -443,7 +443,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "72ceec17",
+ "id": "57012d02",
"metadata": {
"tags": [
"solution"
@@ -469,7 +469,7 @@
},
{
"cell_type": "markdown",
- "id": "9163c4d8",
+ "id": "82e2a1b3",
"metadata": {
"tags": []
},
@@ -483,7 +483,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8ae711b1",
+ "id": "97f80e4a",
"metadata": {
"tags": [
"solution"
@@ -511,7 +511,7 @@
},
{
"cell_type": "markdown",
- "id": "cb79d651",
+ "id": "446d0aab",
"metadata": {
"tags": []
},
@@ -527,7 +527,7 @@
},
{
"cell_type": "markdown",
- "id": "d102ec55",
+ "id": "c0107144",
"metadata": {},
"source": [
"
BONUS Task: Using different attributions.
\n",
@@ -541,7 +541,7 @@
},
{
"cell_type": "markdown",
- "id": "758de8e5",
+ "id": "9a69110b",
"metadata": {},
"source": [
"
Checkpoint 2
\n",
@@ -561,7 +561,7 @@
},
{
"cell_type": "markdown",
- "id": "9dd5ce8d",
+ "id": "d6cd057a",
"metadata": {
"lines_to_next_cell": 0
},
@@ -589,7 +589,7 @@
},
{
"cell_type": "markdown",
- "id": "4497f894",
+ "id": "003cd7d6",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -612,7 +612,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4eba79a8",
+ "id": "c99ee37d",
"metadata": {},
"outputs": [],
"source": [
@@ -644,7 +644,7 @@
},
{
"cell_type": "markdown",
- "id": "4e369a2c",
+ "id": "9a1ba5d0",
"metadata": {
"lines_to_next_cell": 0
},
@@ -659,7 +659,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "69fb25c6",
+ "id": "23a3a52c",
"metadata": {
"tags": [
"solution"
@@ -676,7 +676,7 @@
},
{
"cell_type": "markdown",
- "id": "2f351532",
+ "id": "d0d5441a",
"metadata": {
"tags": []
},
@@ -691,7 +691,7 @@
},
{
"cell_type": "markdown",
- "id": "fef243a6",
+ "id": "1a5da3b0",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -708,7 +708,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1923e528",
+ "id": "7488ac02",
"metadata": {
"lines_to_next_cell": 0,
"tags": [
@@ -722,7 +722,7 @@
},
{
"cell_type": "markdown",
- "id": "331a5aa9",
+ "id": "6152ae96",
"metadata": {
"lines_to_next_cell": 0
},
@@ -733,7 +733,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5da1667d",
+ "id": "47ccc999",
"metadata": {},
"outputs": [],
"source": [
@@ -743,7 +743,7 @@
},
{
"cell_type": "markdown",
- "id": "52c90cba",
+ "id": "235ecf34",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -761,7 +761,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "bbbf486c",
+ "id": "ecd33083",
"metadata": {
"lines_to_next_cell": 0
},
@@ -773,7 +773,7 @@
},
{
"cell_type": "markdown",
- "id": "900d4b51",
+ "id": "46499a30",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -792,7 +792,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "ab86e2bc",
+ "id": "a75d1d8e",
"metadata": {},
"outputs": [],
"source": [
@@ -801,7 +801,7 @@
},
{
"cell_type": "markdown",
- "id": "bb7046e9",
+ "id": "6f1e7957",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -817,7 +817,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "049d4a98",
+ "id": "48c11783",
"metadata": {},
"outputs": [],
"source": [
@@ -826,7 +826,7 @@
},
{
"cell_type": "markdown",
- "id": "8c3e7170",
+ "id": "7a403168",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -838,7 +838,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "209395dd",
+ "id": "197f0b55",
"metadata": {},
"outputs": [],
"source": [
@@ -851,7 +851,7 @@
},
{
"cell_type": "markdown",
- "id": "2513fff6",
+ "id": "14642c7c",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -865,7 +865,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e6413b39",
+ "id": "8567a4c7",
"metadata": {},
"outputs": [],
"source": [
@@ -877,7 +877,7 @@
},
{
"cell_type": "markdown",
- "id": "1022dd64",
+ "id": "d617a57b",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -897,7 +897,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3765f525",
+ "id": "152d1fc2",
"metadata": {},
"outputs": [],
"source": [
@@ -921,7 +921,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e6dc12c2",
+ "id": "1beae4d6",
"metadata": {},
"outputs": [],
"source": [
@@ -931,7 +931,7 @@
},
{
"cell_type": "markdown",
- "id": "f572e6ab",
+ "id": "6fba9909",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -953,7 +953,7 @@
},
{
"cell_type": "markdown",
- "id": "0725d30d",
+ "id": "07dd61b7",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -965,7 +965,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "a50de214",
+ "id": "5e54a534",
"metadata": {
"lines_to_next_cell": 2,
"tags": [
@@ -1035,7 +1035,7 @@
},
{
"cell_type": "markdown",
- "id": "cf7e2d33",
+ "id": "0574b61f",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1047,7 +1047,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f8f84461",
+ "id": "94723f87",
"metadata": {},
"outputs": [],
"source": [
@@ -1063,7 +1063,7 @@
},
{
"cell_type": "markdown",
- "id": "425b6eb8",
+ "id": "a734af98",
"metadata": {
"tags": []
},
@@ -1078,7 +1078,7 @@
},
{
"cell_type": "markdown",
- "id": "0f81593a",
+ "id": "34e2b5cd",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1090,7 +1090,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8fb9650f",
+ "id": "22f64934",
"metadata": {},
"outputs": [],
"source": [
@@ -1110,19 +1110,9 @@
"plt.show()"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "cf9cfe92",
- "metadata": {
- "lines_to_next_cell": 0
- },
- "outputs": [],
- "source": []
- },
{
"cell_type": "markdown",
- "id": "c31b5230",
+ "id": "d8dfbe12",
"metadata": {
"tags": []
},
@@ -1138,7 +1128,7 @@
},
{
"cell_type": "markdown",
- "id": "afb5d3fe",
+ "id": "32991250",
"metadata": {
"tags": []
},
@@ -1148,7 +1138,7 @@
},
{
"cell_type": "markdown",
- "id": "96e6c29a",
+ "id": "216383e1",
"metadata": {
"tags": []
},
@@ -1165,7 +1155,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8fcd2b8a",
+ "id": "54fd4016",
"metadata": {
"title": "Loading the test dataset"
},
@@ -1185,7 +1175,7 @@
},
{
"cell_type": "markdown",
- "id": "e58a66c1",
+ "id": "1dbcc267",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1197,7 +1187,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "b6984bec",
+ "id": "04b20178",
"metadata": {},
"outputs": [],
"source": [
@@ -1210,7 +1200,7 @@
},
{
"cell_type": "markdown",
- "id": "09a033e4",
+ "id": "ea19f383",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1220,12 +1210,12 @@
},
{
"cell_type": "markdown",
- "id": "8d6dafc8",
+ "id": "e10ea481",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
-    "Task 4: Create counterfactuals\n",
+    "Task 4.1: Create counterfactuals\n",
"In the below, we will store the counterfactual images in the `counterfactuals` array.\n",
"\n",
"
\n",
@@ -1238,7 +1228,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "51f8114a",
+ "id": "780e6fdd",
"metadata": {
"tags": [
"solution"
@@ -1275,7 +1265,7 @@
},
{
"cell_type": "markdown",
- "id": "5ed9dc6b",
+ "id": "44a96f81",
"metadata": {
"lines_to_next_cell": 0,
"tags": []
@@ -1287,7 +1277,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "83220035",
+ "id": "83523c2f",
"metadata": {},
"outputs": [],
"source": [
@@ -1300,7 +1290,7 @@
},
{
"cell_type": "markdown",
- "id": "48b400d0",
+ "id": "bee1aa55",
"metadata": {
"tags": []
},
@@ -1315,7 +1305,7 @@
},
{
"cell_type": "markdown",
- "id": "28a66131",
+ "id": "dbac0e2e",
"metadata": {
"tags": []
},
@@ -1326,7 +1316,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "ba72cac7",
+ "id": "7d263995",
"metadata": {},
"outputs": [],
"source": [
@@ -1340,7 +1330,7 @@
},
{
"cell_type": "markdown",
- "id": "7919e4d2",
+ "id": "60081f81",
"metadata": {
"tags": []
},
@@ -1355,7 +1345,7 @@
},
{
"cell_type": "markdown",
- "id": "5727914c",
+ "id": "a037d368",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1370,15 +1360,16 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0db3b836",
+ "id": "308d7f21",
"metadata": {},
"outputs": [],
"source": [
+ "target_class = 0\n",
"batch_size = 4\n",
"batch = [random_test_mnist[i] for i in range(batch_size)]\n",
"x = torch.stack([b[0] for b in batch])\n",
"y = torch.tensor([b[1] for b in batch])\n",
- "x_fake = torch.tensor(counterfactuals[0, :batch_size])\n",
+ "x_fake = torch.tensor(counterfactuals[target_class, :batch_size])\n",
"x = x.to(device).float()\n",
"y = y.to(device)\n",
"x_fake = x_fake.to(device).float()\n",
@@ -1390,7 +1381,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "eda02588",
+ "id": "12c79b2d",
"metadata": {
"title": "Another visualization function"
},
@@ -1410,7 +1401,7 @@
" ax1.imshow(counterfactual_image)\n",
" ax1.set_title(\"Counterfactual\")\n",
" ax1.axis(\"off\")\n",
- " ax2.imshow(np.abs(attribution))\n",
+ " ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution)))\n",
" ax2.set_title(\"Attribution\")\n",
" ax2.axis(\"off\")\n",
" plt.show()"
@@ -1419,7 +1410,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8d6fc8bf",
+ "id": "aa24deb7",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1435,7 +1426,7 @@
},
{
"cell_type": "markdown",
- "id": "b723e27b",
+ "id": "1726fcb6",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1451,7 +1442,122 @@
},
{
"cell_type": "markdown",
- "id": "66ec380e",
+ "id": "93aff28e",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+    "In the lecture, we used the attribution as a mask to gradually move from the original image to the counterfactual image.\n",
+    "This allowed us to classify all of the intermediate images and learn how the class changed over the interpolation.\n",
+    "Here we have a much simpler task, so we have some advantages:\n",
+    "- The counterfactuals are perfect! They already change the bare minimum (trust me).\n",
+    "- The changes are not objects, but colors.\n",
+    "As such, we can use a simple linear interpolation between the images instead."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0389b58c",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+    "Task 4.2: Interpolation\n",
+    "Let's interpolate between the original image and the counterfactual image.\n",
+    "We will create 15 images in between the two and classify each of them.\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "61b9f9a0",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "num_interpolations = 15\n",
+ "alpha = np.linspace(0, 1, num_interpolations + 2)[1:-1]\n",
+ "interpolated_images = [\n",
+ " alpha[i] * x_fake + (1 - alpha[i]) * x for i in range(num_interpolations)\n",
+ "]\n",
+ "interpolated_images = torch.stack(interpolated_images)\n",
+ "interpolated_classifications = [\n",
+ " model(interpolated_images[idx].to(device)) for idx in range(num_interpolations)\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7e70b512",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "# Plot the results\n",
+ "idx = 0\n",
+ "fig, axs = plt.subplots(\n",
+ " batch_size, num_interpolations + 2, figsize=(30, 2 * batch_size)\n",
+ ")\n",
+ "for idx in range(batch_size):\n",
+ " # Plot the original image\n",
+ " axs[idx, 0].imshow(np.transpose(x[idx].cpu().squeeze().numpy(), (1, 2, 0)))\n",
+ " axs[idx, 0].axis(\"off\")\n",
+ " # Use the class as the title\n",
+ " axs[idx, 0].set_title(f\"Image: y={y[idx].item()}\")\n",
+ " # Plot the counterfactual image\n",
+ " axs[idx, -1].imshow(np.transpose(x_fake[idx].cpu().squeeze().numpy(), (1, 2, 0)))\n",
+ " axs[idx, -1].axis(\"off\")\n",
+ " # Use the target class as the title\n",
+ " axs[idx, -1].set_title(f\"CF: y={target_class}\")\n",
+ " for i, ax in enumerate(axs[idx][1:-1]):\n",
+ " ax.imshow(\n",
+ " np.transpose(interpolated_images[i][idx].cpu().squeeze().numpy(), (1, 2, 0))\n",
+ " )\n",
+ " ax.axis(\"off\")\n",
+ " classification = torch.softmax(interpolated_classifications[i][idx], dim=0)\n",
+ " # Plot the classification as the title in order source classification | target classification\n",
+ " ax.set_title(\n",
+ " f\"{classification[y[idx]].item():.2f} | {classification[target_class].item():.2f}\"\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e9126f01",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "Take some time to look at the plot we just made.\n",
+    "On the very left are the images we randomly chose - each image's class is shown in the title.\n",
+ "On the very right are the counterfactual images, all of them made with the same prototype as a style source - the target class is shown in the title.\n",
+ "In between are the interpolated images - their title shows their classification as \"source classification | target classification\".\n",
+ "This is a lot to take in, so take your time! Once you're ready, we can move on to the questions."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d50ea18c",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+    "Questions\n",
+    "\n",
+    "- Do the images change smoothly from one class to another?\n",
+    "- Can you see any patterns in the changes?\n",
+    "- What happens when the original image and the counterfactual image are of the same class?\n",
+    "- Based on this, would you trust this classifier on unseen images more or less than you did before?\n",
+    ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a93a5437",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1461,12 +1567,13 @@
"- Created a StarGAN that can change the class of an image\n",
"- Evaluated the StarGAN on unseen data\n",
"- Used the StarGAN to create counterfactual images\n",
- "- Used the counterfactual images to highlight the differences between classes\n"
+ "- Used the counterfactual images to highlight the differences between classes\n",
+ "- Interpolated between the images to see how the classifier behaves\n"
]
},
{
"cell_type": "markdown",
- "id": "7592e83d",
+ "id": "b8f44b99",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1489,7 +1596,7 @@
},
{
"cell_type": "markdown",
- "id": "a7400c5c",
+ "id": "aafe8844",
"metadata": {},
"source": [
"
Task 5.1: Explore the style space
\n",
@@ -1501,7 +1608,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f5493fd4",
+ "id": "0128d7bd",
"metadata": {},
"outputs": [],
"source": [
@@ -1536,7 +1643,7 @@
},
{
"cell_type": "markdown",
- "id": "d57f4253",
+ "id": "609028bd",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1552,7 +1659,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4160a32f",
+ "id": "8b7185c5",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1579,7 +1686,7 @@
},
{
"cell_type": "markdown",
- "id": "650c6766",
+ "id": "772e4231",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1593,7 +1700,7 @@
},
{
"cell_type": "markdown",
- "id": "f88d9484",
+ "id": "7d159c2b",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1610,7 +1717,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0b57d333",
+ "id": "13eb454f",
"metadata": {},
"outputs": [],
"source": [
@@ -1633,7 +1740,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4bb3433e",
+ "id": "3ddc6a07",
"metadata": {
"lines_to_next_cell": 0
},
@@ -1642,7 +1749,7 @@
},
{
"cell_type": "markdown",
- "id": "516e0112",
+ "id": "303338b1",
"metadata": {},
"source": [
"
Questions
\n",
@@ -1654,7 +1761,7 @@
},
{
"cell_type": "markdown",
- "id": "63cef16c",
+ "id": "e18da1e0",
"metadata": {},
"source": [
"
Checkpoint 5
\n",
@@ -1672,14 +1779,32 @@
},
{
"cell_type": "markdown",
- "id": "e0cb271b",
+ "id": "b74f495b",
"metadata": {
- "lines_to_next_cell": 0,
- "tags": [
- "solution"
- ]
+ "lines_to_next_cell": 0
},
"source": [
+    "# Bonus!\n",
+    "If you have extra time, you can try to break the StarGAN!\n",
+    "There are a lot of little things that we did to make sure that it runs correctly - but what if we didn't?\n",
+    "Some things you might want to try:\n",
+    "- What happens if you don't use the EMA model?\n",
+    "- What happens if you change the learning rates?\n",
+    "- What happens if you add a Sigmoid activation to the output of the style encoder?\n",
+    "See what else you can think of, and see how finicky training a GAN can be!"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e0cb271b",
+   "metadata": {
+    "lines_to_next_cell": 0,
+    "tags": [
+     "solution"
+    ]
+   },
+   "source": [
"The colors for the classes are sampled from matplotlib colormaps! They are the four seasons: spring, summer, autumn, and winter.\n",
"Check your style space again to see if you can see the patterns now!"
]
@@ -1687,7 +1801,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "86681224",
+ "id": "52f77a00",
"metadata": {
"tags": [
"solution"
diff --git a/solution.py b/solution.py
index cd92aa3..e8c7455 100644
--- a/solution.py
+++ b/solution.py
@@ -222,7 +222,7 @@ def visualize_color_attribution(attribution, original_image):
ax1.imshow(original_image)
ax1.set_title("Image")
ax1.axis("off")
- ax2.imshow(np.abs(attribution))
+ ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution)))
ax2.set_title("Attribution")
ax2.axis("off")
plt.show()
@@ -768,7 +768,6 @@ def copy_parameters(source_model, target_model):
ax.axis("off")
plt.show()
-# %%
# %% [markdown] tags=[]
#
# Checkpoint 3
# You've now learned the basics of what makes up a StarGAN, and details on how to perform adversarial training.
@@ -814,7 +813,7 @@ def copy_parameters(source_model, target_model):
# %% [markdown]
# Now we need to use these prototypes to create counterfactual images!
# %% [markdown]
-# Task 4: Create counterfactuals
+# Task 4.1: Create counterfactuals
# In the below, we will store the counterfactual images in the `counterfactuals` array.
#
#
@@ -917,11 +916,12 @@ def copy_parameters(source_model, target_model):
# Let's try putting the two together to see if we can figure out what exactly makes a class.
#
# %%
+target_class = 0
batch_size = 4
batch = [random_test_mnist[i] for i in range(batch_size)]
x = torch.stack([b[0] for b in batch])
y = torch.tensor([b[1] for b in batch])
-x_fake = torch.tensor(counterfactuals[0, :batch_size])
+x_fake = torch.tensor(counterfactuals[target_class, :batch_size])
x = x.to(device).float()
y = y.to(device)
x_fake = x_fake.to(device).float()
@@ -945,7 +945,7 @@ def visualize_color_attribution_and_counterfactual(
ax1.imshow(counterfactual_image)
ax1.set_title("Counterfactual")
ax1.axis("off")
- ax2.imshow(np.abs(attribution))
+ ax2.imshow(np.abs(attribution) / np.max(np.abs(attribution)))
ax2.set_title("Attribution")
ax2.axis("off")
plt.show()
@@ -967,12 +967,85 @@ def visualize_color_attribution_and_counterfactual(
#
#
# %% [markdown]
+# In the lecture, we used the attribution as a mask to gradually move from the original image to the counterfactual image.
+# This allowed us to classify all of the intermediate images and learn how the class changed over the interpolation.
+# Here we have a much simpler task, so we have some advantages:
+# - The counterfactuals are perfect! They already change the bare minimum (trust me).
+# - The changes are not objects, but colors.
+# As such, we can use a simple linear interpolation between the images instead.
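+#
+# For reference, the lecture's masked version would look roughly like this
+# (a sketch, assuming numpy arrays `original`, `counterfactual`, and a
+# same-shaped `attribution` map - none of these names are defined here):
+# ```python
+# mask = np.abs(attribution) / np.max(np.abs(attribution))  # per-pixel weight in [0, 1]
+# # at threshold t, copy only the most strongly attributed pixels from the counterfactual
+# hybrids = [np.where(mask > t, counterfactual, original) for t in np.linspace(1, 0, 10)]
+# ```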
+# %% [markdown]
+# Task 4.2: Interpolation
+# Let's interpolate between the original image and the counterfactual image.
+# We will create 15 images in between the two and classify each of them.
+#
+# %%
+num_interpolations = 15
+alpha = np.linspace(0, 1, num_interpolations + 2)[1:-1]
+interpolated_images = [
+ alpha[i] * x_fake + (1 - alpha[i]) * x for i in range(num_interpolations)
+]
+interpolated_images = torch.stack(interpolated_images)
+interpolated_classifications = [
+ model(interpolated_images[idx].to(device)) for idx in range(num_interpolations)
+]
+# %%
+# Plot the results
+idx = 0
+fig, axs = plt.subplots(
+ batch_size, num_interpolations + 2, figsize=(30, 2 * batch_size)
+)
+for idx in range(batch_size):
+ # Plot the original image
+ axs[idx, 0].imshow(np.transpose(x[idx].cpu().squeeze().numpy(), (1, 2, 0)))
+ axs[idx, 0].axis("off")
+ # Use the class as the title
+ axs[idx, 0].set_title(f"Image: y={y[idx].item()}")
+ # Plot the counterfactual image
+ axs[idx, -1].imshow(np.transpose(x_fake[idx].cpu().squeeze().numpy(), (1, 2, 0)))
+ axs[idx, -1].axis("off")
+ # Use the target class as the title
+ axs[idx, -1].set_title(f"CF: y={target_class}")
+ for i, ax in enumerate(axs[idx][1:-1]):
+ ax.imshow(
+ np.transpose(interpolated_images[i][idx].cpu().squeeze().numpy(), (1, 2, 0))
+ )
+ ax.axis("off")
+ classification = torch.softmax(interpolated_classifications[i][idx], dim=0)
+ # Plot the classification as the title in order source classification | target classification
+ ax.set_title(
+ f"{classification[y[idx]].item():.2f} | {classification[target_class].item():.2f}"
+ )
+# %% [markdown]
+# Take some time to look at the plot we just made.
+# On the very left are the images we randomly chose - each image's class is shown in the title.
+# On the very right are the counterfactual images, all of them made with the same prototype as a style source - the target class is shown in the title.
+# In between are the interpolated images - their title shows their classification as "source classification | target classification".
+# This is a lot to take in, so take your time! Once you're ready, we can move on to the questions.
+# %% [markdown]
+# Questions
+#
+# - Do the images change smoothly from one class to another?
+# - Can you see any patterns in the changes?
+# - What happens when the original image and the counterfactual image are of the same class?
+# - Based on this, would you trust this classifier on unseen images more or less than you did before?
+#
+# %% [markdown]
# Checkpoint 4
# At this point you have:
# - Created a StarGAN that can change the class of an image
# - Evaluated the StarGAN on unseen data
# - Used the StarGAN to create counterfactual images
# - Used the counterfactual images to highlight the differences between classes
+# - Interpolated between the images to see how the classifier behaves
#
# %% [markdown]
# # Part 5: Exploring the Style Space, finding the answer
@@ -1100,7 +1164,37 @@ def visualize_color_attribution_and_counterfactual(
# If you have any questions, feel free to ask them in the chat!
# And check the Solutions exercise for a definite answer to how these classes are defined!
-# %% [markdown] tags=["solution"]
+# %% [markdown]
+# # Bonus!
+# If you have extra time, you can try to break the StarGAN!
+# There are a lot of little things that we did to make sure that it runs correctly - but what if we didn't?
+# Some things you might want to try:
+# - What happens if you don't use the EMA model?
+# - What happens if you change the learning rates?
+# - What happens if you add a Sigmoid activation to the output of the style encoder?
+# See what else you can think of, and see how finicky training a GAN can be!
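+#
+# For instance, a minimal sketch of what an EMA update looks like (an assumption
+# for illustration - the exercise's own `copy_parameters` helper may differ):
+# ```python
+# def ema_update(source_model, target_model, beta=0.99):
+#     # target <- beta * target + (1 - beta) * source, parameter by parameter
+#     for src, tgt in zip(source_model.parameters(), target_model.parameters()):
+#         tgt.data.mul_(beta).add_(src.data, alpha=1 - beta)
+# ```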
+
+# %% [markdown] tags=["solution"]
# The colors for the classes are sampled from matplotlib colormaps! They are the four seasons: spring, summer, autumn, and winter.
# Check your style space again to see if you can see the patterns now!
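+#
+# A minimal sketch of the idea (an assumption for illustration - the real data
+# generation may differ in the details):
+# ```python
+# import numpy as np
+# import matplotlib.pyplot as plt
+#
+# cmap = plt.get_cmap("spring")  # one colormap per class
+# color = np.array(cmap(np.random.rand())[:3])  # sample an RGB color along the colormap
+# colored_digit = gray_digit[..., None] * color  # tint a grayscale digit (`gray_digit` is hypothetical)
+# ```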
# %% tags=["solution"]