
Commit

fix cwk
SobhanMP committed Jan 29, 2022
1 parent f126780 commit 729a55c
Showing 3 changed files with 32 additions and 20 deletions.
34 changes: 20 additions & 14 deletions con.ipynb
@@ -210,22 +210,26 @@
"\n",
"\n",
"\n",
"------------------------\n",
"training with 1 replays------------------------\n",
"\n",
"\n",
"\n",
" training with 1 replays\n",
"train \t 1: 1.5880 40.0% 71.2s\n",
"train \t 2: 1.1595 57.8% 71.2s\n",
"train \t 3: 0.9540 65.9% 71.5s\n",
"train \t 4: 0.8208 71.3% 71.4s\n",
"train \t 5: 0.7153 74.9% 71.3s\n",
"train \t 6: 0.6392 77.7% 71.2s\n",
"train \t 7: 0.5713 80.4% 71.1s\n",
"train \t 8: 0.5228 81.9% 71.0s\n",
"train \t 9: 0.4808 83.5% 70.9s\n",
"train \t 10: 0.4478 84.7% 70.8s\n",
"train \t 11: 0.4202 85.6% 70.8s\n"
"\n",
"train \t 1: 1.5868 40.0% 71.4s\n",
"train \t 2: 1.1593 57.8% 71.1s\n",
"train \t 3: 0.9608 65.6% 71.3s\n",
"train \t 4: 0.8251 70.8% 71.4s\n",
"train \t 5: 0.7317 74.5% 71.4s\n",
"train \t 6: 0.6537 77.3% 71.2s\n",
"train \t 7: 0.5980 79.3% 71.5s\n",
"train \t 8: 0.5447 81.1% 71.0s\n",
"train \t 9: 0.5089 82.7% 70.9s\n",
"train \t 10: 0.4680 84.0% 70.8s\n",
"train \t 11: 0.4419 84.8% 70.8s\n",
"train \t 12: 0.4091 85.8% 70.7s\n",
"train \t 13: 0.3880 86.7% 70.6s\n",
"train \t 14: 0.3713 87.1% 70.6s\n",
"train \t 15: 0.3504 87.9% 70.5s\n"
]
}
],
@@ -234,7 +238,7 @@
"srl = defaultdict(lambda : defaultdict(lambda : []))\n",
"\n",
"for K in training_with_replay_Ks:\n",
" print(f'\\n\\n\\n\\n\\n------------------------\\n\\n\\n\\n training with {K} replays')\n",
" print(f'\\n\\n\\n\\n\\ntraining with {K} replays------------------------\\n\\n\\n\\n')\n",
"\n",
" model = build_model(False)\n",
" optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay)\n",
@@ -301,6 +305,7 @@
"free_logs = defaultdict(lambda : defaultdict(lambda :[]))\n",
"\n",
"for K in free_Ks:\n",
" print(f'\\n\\n\\n\\n\\ntraining with {K} replays------------------------\\n\\n\\n\\n')\n",
" model = build_model()\n",
" optimizer = optim.Adam(model.parameters())\n",
"\n",
@@ -345,6 +350,7 @@
"pgd_logs = defaultdict(lambda : defaultdict(lambda : []))\n",
"\n",
"for K in PGD_Ks:\n",
" print(f'\\n\\n\\n\\n\\ntraining with {K}-PGD------------------------\\n\\n\\n\\n')\n",
" model = build_model(False)\n",
" optimizer = optim.Adam(model.parameters())\n",
" \n",
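The notebook loops above (for `training_with_replay_Ks`, `free_Ks`, and `PGD_Ks`) all follow the same pattern, and the change simply prints the banner before that configuration's epoch logs rather than after the trailing blank lines. The loop bodies are truncated in this diff; a purely hypothetical reconstruction of their shape, built only from the helpers in src/utils.py further down (imports such as `defaultdict` and `optim`, and names like `epochs`, `trainloader`, `testloader`, are assumed from earlier cells):

    srl = defaultdict(lambda: defaultdict(list))

    for K in training_with_replay_Ks:      # assumed to be defined earlier in the notebook
        # Banner first (the change above), so it precedes this configuration's epoch logs.
        print(f'\n\n\n\n\ntraining with {K} replays------------------------\n\n\n\n')

        model = build_model(False)         # fresh model and optimizer per replay count
        optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay)

        for epoch in range(epochs):        # `epochs` assumed defined earlier
            srl[K]['train'].append(train_with_replay(K, model, trainloader, optimizer, epoch))
            srl[K]['val'].append(run_val(model, testloader, epoch))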
4 changes: 3 additions & 1 deletion src/attacks.py
@@ -86,9 +86,11 @@ def __call__(self, model, x, y):

correct_logit = p[torch.arange(p.shape[0]), y]
wrong_logit = ((1 - mask) * p - self.c * mask).max(axis=1)[0]
loss = F.relu(correct_logit - wrong_logit).sum()
loss = -F.relu(correct_logit - wrong_logit).sum()
loss.backward()

self.step(noise)

self.zero(model, noise)
self.finalize(noise)
return self.adv(x, noise)
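A note on the sign flip above, hedged since only part of the attack's `__call__` is visible here: `correct_logit - wrong_logit` is the Carlini-Wagner-style margin between the true-class logit and the strongest other logit, with `self.c * mask` pushing the true class out of the `max`. Assuming `self.step(noise)` applies a loss-maximizing, PGD-style update from the gradients left by `loss.backward()`, the hinge has to be negated so that each step shrinks the margin and drives the input toward misclassification; with the old sign the attack would have reinforced the correct prediction instead. A self-contained sketch of the negated objective:

    import torch
    import torch.nn.functional as F

    def cw_margin_loss(logits, y, c=50.0):
        # Negated CW-style hinge mirroring the fixed line above. `c` stands in
        # for `self.c` (value assumed); subtracting it from the true-class entry
        # keeps the max from selecting the true class. Maximizing this quantity
        # minimizes relu(correct - wrong), i.e. pushes the true-class logit
        # below the strongest competing logit.
        mask = F.one_hot(y, num_classes=logits.shape[1]).to(logits.dtype)
        correct_logit = logits[torch.arange(logits.shape[0]), y]
        wrong_logit = ((1 - mask) * logits - c * mask).max(dim=1)[0]
        return -F.relu(correct_logit - wrong_logit).sum()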
14 changes: 9 additions & 5 deletions src/utils.py
@@ -1,3 +1,4 @@
from calendar import EPOCH
from collections import defaultdict
import time
import torch
@@ -33,14 +34,15 @@ def add(self, v, m=1):
Track accuracy, loss and runtime
"""
class Logisticator:
def __init__(self) -> None:
def __init__(self, epoch=None) -> None:
self._acc = Collectinator()
self._loss = Collectinator()
self.acc = 0
self.loss = 0
self.now = time.time()
self.end_time = None

self.epoch = epoch

def add(self, acc, loss, m):
self._acc.add(acc, m)
self.acc = self._acc.mean
@@ -73,7 +75,7 @@ def accuracy(outputs, labels):
def train_with_replay(K, model, trainloader, optimizer, epoch,
input_func=lambda x, y: x,
after_func=lambda model: None):
logs = Logisticator()
logs = Logisticator(epoch)

model.train()

@@ -97,11 +99,12 @@ def train_with_replay(K, model, trainloader, optimizer, epoch,
logs.add(acc, loss.item(), inputs.size(0))
print(f'train \t {epoch + 1}: {logs}')
return logs

def run_val(model, testloader, epoch):
model.train(False)
# validation loss
with torch.no_grad():
logs = Logisticator()
logs = Logisticator(epoch)

for data in testloader:
inputs, labels = map(lambda x: x.cuda(), data)
@@ -113,10 +116,11 @@ def run_val(model, testloader, epoch):

print(f'val \t {epoch + 1}: {logs}')
return logs

def run_attacks(logholder, attacks, attack_names, model, testloader, epoch):
model.train(False)
for (attack, name) in zip(attacks, attack_names):
logs = Logisticator()
logs = Logisticator(epoch)
logholder[f'adv_test/{name}'].append(logs)
for data in testloader:
inputs, labels = map(lambda x: x.cuda(), data)
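The rest of the utils.py change is mechanical: `Logisticator` now remembers the epoch it was created for, and every call site (`train_with_replay`, `run_val`, `run_attacks`) passes `epoch` through, so the entries appended to `srl`, `free_logs`, `pgd_logs`, or `logholder` can be matched back to an epoch without relying on list position. A minimal usage sketch (only names defined in this file; the numbers are made up):

    logs = Logisticator(epoch=3)            # tag the logger with its epoch
    logs.add(acc=0.875, loss=0.41, m=128)   # one batch: accuracy, loss, batch size
    logs.add(acc=0.902, loss=0.37, m=128)
    print(logs.epoch, logs.acc, logs.loss)  # epoch tag plus running means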
