Commit

bug fixed, added demo and tests
shivamsaboo17 committed Mar 29, 2019
1 parent 64b99d3 commit 713f5f1
Showing 3 changed files with 355 additions and 2,259 deletions.
349 changes: 349 additions & 0 deletions demo.ipynb
@@ -0,0 +1,349 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Untitled1.ipynb",
"version": "0.3.2",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"cells": [
{
"metadata": {
"id": "C7Ppv-uc9DWM",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from elastic_weight_consolidation import ElasticWeightConsolidation"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "qe22sCzx9DWQ",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"from torch.utils.data import Dataset, DataLoader\n",
"from torchvision import datasets, transforms"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "Y_LMkmXG9DWV",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"mnist_train = datasets.MNIST(\"../data\", train=True, download=True, transform=transforms.ToTensor())\n",
"mnist_test = datasets.MNIST(\"../data\", train=False, download=True, transform=transforms.ToTensor())\n",
"train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)\n",
"test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "YrKlgL6t9zJe",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"class LinearLayer(nn.Module):\n",
" def __init__(self, input_dim, output_dim, act='relu', use_bn=False):\n",
" super(LinearLayer, self).__init__()\n",
" self.use_bn = use_bn\n",
" self.lin = nn.Linear(input_dim, output_dim)\n",
" self.act = nn.ReLU() if act == 'relu' else act\n",
" if use_bn:\n",
" self.bn = nn.BatchNorm1d(output_dim)\n",
" def forward(self, x):\n",
" if self.use_bn:\n",
" return self.bn(self.act(self.lin(x)))\n",
" return self.act(self.lin(x))\n",
"\n",
"class Flatten(nn.Module):\n",
"\n",
" def forward(self, x):\n",
" return x.view(x.shape[0], -1)\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "44d9meQa9DWc",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"class BaseModel(nn.Module):\n",
" \n",
" def __init__(self, num_inputs, num_hidden, num_outputs):\n",
" super(BaseModel, self).__init__()\n",
" self.f1 = Flatten()\n",
" self.lin1 = LinearLayer(num_inputs, num_hidden, use_bn=True)\n",
" self.lin2 = LinearLayer(num_hidden, num_hidden, use_bn=True)\n",
" self.lin3 = nn.Linear(num_hidden, num_outputs)\n",
" \n",
" def forward(self, x):\n",
" return self.lin3(self.lin2(self.lin1(self.f1(x))))"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "_17XW9359DWf",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"crit = nn.CrossEntropyLoss()"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "SpBrwjk89DWi",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"ewc = ElasticWeightConsolidation(BaseModel(28 * 28, 100, 10), crit=crit, lr=1e-4)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "CMQGk-E19DWl",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"from tqdm import tqdm"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "gmbrFvJm9DWn",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 87
},
"outputId": "db2ca466-76dc-4d1c-fa32-8dc672a12a8f"
},
"cell_type": "code",
"source": [
"for _ in range(4):\n",
" for input, target in tqdm(train_loader):\n",
" ewc.forward_backward_update(input, target)"
],
"execution_count": 41,
"outputs": [
{
"output_type": "stream",
"text": [
"100%|██████████| 600/600 [00:07<00:00, 75.32it/s]\n",
"100%|██████████| 600/600 [00:07<00:00, 76.07it/s]\n",
"100%|██████████| 600/600 [00:08<00:00, 74.12it/s]\n",
"100%|██████████| 600/600 [00:08<00:00, 73.90it/s]\n"
],
"name": "stderr"
}
]
},
{
"metadata": {
"id": "8HwlRJkI9DWt",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"ewc.register_ewc_params(mnist_train, 100, 300)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "NvJW68IB9DWw",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"f_mnist_train = datasets.FashionMNIST(\"../data\", train=True, download=True, transform=transforms.ToTensor())\n",
"f_mnist_test = datasets.FashionMNIST(\"../data\", train=False, download=True, transform=transforms.ToTensor())\n",
"f_train_loader = DataLoader(f_mnist_train, batch_size = 100, shuffle=True)\n",
"f_test_loader = DataLoader(f_mnist_test, batch_size = 100, shuffle=False)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "SzQbVudz9DWy",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 87
},
"outputId": "bdcb55c5-d40a-4a7a-dca5-4652076e8033"
},
"cell_type": "code",
"source": [
"for _ in range(4):\n",
" for input, target in tqdm(f_train_loader):\n",
" ewc.forward_backward_update(input, target)"
],
"execution_count": 44,
"outputs": [
{
"output_type": "stream",
"text": [
"100%|██████████| 600/600 [00:09<00:00, 62.14it/s]\n",
"100%|██████████| 600/600 [00:09<00:00, 66.03it/s]\n",
"100%|██████████| 600/600 [00:09<00:00, 66.56it/s]\n",
"100%|██████████| 600/600 [00:09<00:00, 65.95it/s]\n"
],
"name": "stderr"
}
]
},
{
"metadata": {
"id": "L8n6PX5w9DW2",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"ewc.register_ewc_params(f_mnist_train, 100, 300)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "fUqvbeO79DW4",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"def accu(model, dataloader):\n",
" model = model.eval()\n",
" acc = 0\n",
" for input, target in dataloader:\n",
" o = model(input)\n",
" acc += (o.argmax(dim=1).long() == target).float().mean()\n",
" return acc / len(dataloader)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "aOIOBZhp9DW6",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "b260dee2-3c7d-4a9f-be83-8ac412a32f5c"
},
"cell_type": "code",
"source": [
"accu(ewc.model, f_test_loader)"
],
"execution_count": 47,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor(0.8188)"
]
},
"metadata": {
"tags": []
},
"execution_count": 47
}
]
},
{
"metadata": {
"id": "hFdW_33Y9DW-",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "38b2dad0-bfb4-48e5-ec01-d848cc8c1593"
},
"cell_type": "code",
"source": [
"accu(ewc.model, test_loader)"
],
"execution_count": 48,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor(0.7027)"
]
},
"metadata": {
"tags": []
},
"execution_count": 48
}
]
},
{
"metadata": {
"id": "Fkni7xkY-tRI",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
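
For readers who would rather not parse the raw notebook JSON, the demo above boils down to the following script. It is a condensed sketch using the same calls, datasets, and hyperparameters as the notebook; `BaseModel` and `accu` refer to the model and accuracy helper defined in the notebook cells, and the commented accuracies quote the notebook's recorded outputs.

import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from elastic_weight_consolidation import ElasticWeightConsolidation

# BaseModel and accu are the model class and accuracy helper from the notebook cells above.

# Task A data: MNIST
mnist_train = datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(mnist_train, batch_size=100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=100, shuffle=False)

# Task B data: FashionMNIST
f_mnist_train = datasets.FashionMNIST("../data", train=True, download=True, transform=transforms.ToTensor())
f_mnist_test = datasets.FashionMNIST("../data", train=False, download=True, transform=transforms.ToTensor())
f_train_loader = DataLoader(f_mnist_train, batch_size=100, shuffle=True)
f_test_loader = DataLoader(f_mnist_test, batch_size=100, shuffle=False)

ewc = ElasticWeightConsolidation(BaseModel(28 * 28, 100, 10), crit=nn.CrossEntropyLoss(), lr=1e-4)

# Train on MNIST, then register parameter means and Fisher estimates for the EWC penalty.
for _ in range(4):
    for input, target in train_loader:
        ewc.forward_backward_update(input, target)
ewc.register_ewc_params(mnist_train, 100, 300)

# Train on FashionMNIST; the penalty discourages drifting from weights important to MNIST.
for _ in range(4):
    for input, target in f_train_loader:
        ewc.forward_backward_update(input, target)
ewc.register_ewc_params(f_mnist_train, 100, 300)

print(accu(ewc.model, f_test_loader))  # notebook run: tensor(0.8188)
print(accu(ewc.model, test_loader))    # notebook run: tensor(0.7027)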
11 changes: 6 additions & 5 deletions elastic_weight_consolidation.py
@@ -4,11 +4,12 @@
 import torch.optim as optim
 from torch import autograd
 import numpy as np
+from torch.utils.data import DataLoader


 class ElasticWeightConsolidation:

-    def __init__(self, model, crit, lr=0.001, weight=0.1):
+    def __init__(self, model, crit, lr=0.001, weight=1000000):
         self.model = model
         self.weight = weight
         self.crit = crit
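
In this hunk the default penalty weight jumps from 0.1 to 1000000. In elastic weight consolidation, `weight` plays the role of the coefficient (usually written λ) on the quadratic penalty that keeps parameters close to the means and Fisher-information estimates registered after the previous task. For reference, the standard EWC objective (Kirkpatrick et al., 2017) is

    \mathcal{L}(\theta) = \mathcal{L}_{\text{task}}(\theta) + \frac{\lambda}{2} \sum_i F_i \, (\theta_i - \theta_i^{*})^2

where F_i is the diagonal Fisher estimate and \theta_i^{*} the stored `_estimated_mean` buffer. The exact scaling inside this repository's consolidation loss is not visible in these hunks, so treat the 1/2 factor as the textbook form rather than a claim about the code.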
@@ -20,7 +21,7 @@ def _update_mean_params(self):
             self.model.register_buffer(_buff_param_name+'_estimated_mean', param.data.clone())

     def _update_fisher_params(self, current_ds, batch_size, num_batch):
-        dl = DataLoader(ds, batch_size, shuffle=True)
+        dl = DataLoader(current_ds, batch_size, shuffle=True)
         log_liklihoods = []
         for i, (input, target) in enumerate(dl):
             if i > num_batch:
@@ -56,8 +57,8 @@ def forward_backward_update(self, input, target):
         loss.backward()
         self.optimizer.step()

-    def save(self):
-        torch.save(self.model)
+    def save(self, filename):
+        torch.save(self.model, filename)

     def load(self, filename):
         self.model = torch.load(filename)
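
With the corrected signature, persisting a consolidated model requires an explicit path, and load simply rebinds the model attribute. A minimal usage sketch, continuing from the `ewc` object built in the demo above (the filename is illustrative, not something the repository defines):

ewc.save('ewc_mnist.pt')   # hypothetical path; forwards to torch.save(self.model, filename)
ewc.load('ewc_mnist.pt')   # rebinds ewc.model to the object returned by torch.load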