Skip to content

Commit

Permalink
Larissa's review (#5)
Browse files Browse the repository at this point in the history
* A few corrections and two todos for Diane

* Add counterfactul explanation and end hint

* Update exercise

* Update instructions

* Update exercise with new instructions

* Revert "Update exercise with new instructions"

This reverts commit c8ba415.

---------

Co-authored-by: msschwartz21 <[email protected]>
  • Loading branch information
adjavon and msschwartz21 authored Aug 27, 2023
1 parent b3ab18f commit 56302c7
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions solution.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,7 @@
"metadata": {},
"source": [
"We make a `torch` `DataLoader` that takes our `sampler` to create batches of eight images and their corresponding labels.\n",
"Each image should be randomly and equally selected from the six available classes (i.e., for each image sample pick a random class, then pick a random image from this class).\n",
"\n",
"We additionally create a validation data loader and a test data loader.\n",
"These do not need to be sampled in a special way, and can load more images at once because the evaluation pass is less memory intensive than the training pass."
"Each image should be randomly and equally selected from the six available classes (i.e., for each image sample pick a random class, then pick a random image from this class)."
]
},
{
Expand Down Expand Up @@ -486,15 +483,13 @@
"tags": []
},
"source": [
"<div class=\"alert alert-block alert-info\"><h3>Task 1.2: Create a prediction loop</h3>\n",
"\n",
"We now have a classifier that can discriminate between images of different types. If you used the images we provided, the classifier is not perfect (you should get an accuracy of around 80%), but pretty good considering that there are six different types of images.\n",
"<div class=\"alert alert-block alert-info\"><h3>Task 1.2: Create a prediction function</h3>\n",
"\n",
"To understand the performance of the classifier, we need to run predictions on the validation dataset so that we can get accuracy and eventually a confusiom natrix.\n",
"To understand the performance of the classifier, we need to run predictions on the validation dataset so that we can get accuracy during training, and eventually a confusiom natrix. In practice, this will allow us to stop before we overfit, although in this exercise we will probably not be training that long. Then, later, we can use the same prediction function on test data.\n",
"\n",
"\n",
"TODO\n",
"Modify the `evaluation` so that it returns a paired list of predicted class vs ground truth to produce a confusion matrix. You'll need to do the following steps.\n",
"Modify `predict` so that it returns a paired list of predicted class vs ground truth to produce a confusion matrix. You'll need to do the following steps.\n",
"- Get the model output for the batch of data `(x, y)`\n",
"- Turn the model output into a probability\n",
"- Get the class predictions from the probabilities\n",
Expand Down Expand Up @@ -707,7 +702,13 @@
" checkpoint = torch.load(\n",
" \"checkpoints/synapses/classifier/vgg_checkpoint\", map_location=device\n",
" )\n",
" model.load_state_dict(checkpoint[\"model_state_dict\"])"
" model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
"\n",
" \n",
"# And check the (hopefully much better) accuracy\n",
"predictions, ground_truths = predict(test_dataset, \"Test\")\n",
"accuracy = accuracy_score(ground_truths, predictions)\n",
"print(f\"Final_final_v2_last_one test accuracy: {accuracy}\")"
]
},
{
Expand Down Expand Up @@ -3240,6 +3241,10 @@
"display_name": "09_knowledge_extraction",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.4"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 56302c7

Please sign in to comment.