From 56302c7aa71923b7d724861293622ab81ae0261b Mon Sep 17 00:00:00 2001
From: adjavon
Date: Sun, 27 Aug 2023 14:06:45 -0400
Subject: [PATCH] Larissa's review (#5)

* A few corrections and two todos for Diane
* Add counterfactual explanation and end hint
* Update exercise
* Update instructions
* Update exercise with new instructions
* Revert "Update exercise with new instructions"

This reverts commit c8ba415b8869cc6ee1a1a440fe557fb5d91579eb.

---------

Co-authored-by: msschwartz21
---
 solution.ipynb | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/solution.ipynb b/solution.ipynb
index 64f611c..ce40e13 100644
--- a/solution.ipynb
+++ b/solution.ipynb
@@ -184,10 +181,7 @@
    "metadata": {},
    "source": [
     "We make a `torch` `DataLoader` that takes our `sampler` to create batches of eight images and their corresponding labels.\n",
-    "Each image should be randomly and equally selected from the six available classes (i.e., for each image sample pick a random class, then pick a random image from this class).\n",
-    "\n",
-    "We additionally create a validation data loader and a test data loader.\n",
-    "These do not need to be sampled in a special way, and can load more images at once because the evaluation pass is less memory intensive than the training pass."
+    "Each image should be randomly and equally selected from the six available classes (i.e., for each image sample pick a random class, then pick a random image from this class)."
    ]
   },
   {
@@ -486,15 +483,13 @@
    "tags": []
   },
   "source": [
-    "
Task 1.2: Create a prediction loop
\n",
-    "\n",
-    "We now have a classifier that can discriminate between images of different types. If you used the images we provided, the classifier is not perfect (you should get an accuracy of around 80%), but pretty good considering that there are six different types of images.\n",
+    "
Task 1.2: Create a prediction function
\n",
     "\n",
-    "To understand the performance of the classifier, we need to run predictions on the validation dataset so that we can get accuracy and eventually a confusiom natrix.\n",
+    "To understand the performance of the classifier, we need to run predictions on the validation dataset so that we can get accuracy during training, and eventually a confusion matrix. In practice, this will allow us to stop before we overfit, although in this exercise we will probably not be training that long. Then, later, we can use the same prediction function on test data.\n",
     "\n",
     "\n",
     "TODO\n",
-    "Modify the `evaluation` so that it returns a paired list of predicted class vs ground truth to produce a confusion matrix. You'll need to do the following steps.\n",
+    "Modify `predict` so that it returns a paired list of predicted class vs ground truth to produce a confusion matrix. You'll need to do the following steps.\n",
     "- Get the model output for the batch of data `(x, y)`\n",
     "- Turn the model output into a probability\n",
     "- Get the class predictions from the probabilities\n",
@@ -707,7 +702,13 @@
     " checkpoint = torch.load(\n",
     " \"checkpoints/synapses/classifier/vgg_checkpoint\", map_location=device\n",
     " )\n",
-    " model.load_state_dict(checkpoint[\"model_state_dict\"])"
+    " model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
+    "\n",
+    " \n",
+    "# And check the (hopefully much better) accuracy\n",
+    "predictions, ground_truths = predict(test_dataset, \"Test\")\n",
+    "accuracy = accuracy_score(ground_truths, predictions)\n",
+    "print(f\"Final_final_v2_last_one test accuracy: {accuracy}\")"
    ]
   },
   {
@@ -3240,6 +3241,10 @@
    "display_name": "09_knowledge_extraction",
    "language": "python",
    "name": "python3"
+   },
+   "language_info": {
+    "name": "python",
+    "version": "3.10.4"
    }
   },
  "nbformat": 4,