
Commit 39da2f6
added scheduler
ebagdasa committed Apr 7, 2022
1 parent af12c08
Showing 3 changed files with 9 additions and 12 deletions.
configs/cifar10_params.yaml (4 additions & 5 deletions)
@@ -3,22 +3,21 @@ synthesizer: Pattern
 
 batch_size: 64
 test_batch_size: 100
-lr: 0.001
+lr: 0.1
 momentum: 0.9
 decay: 0.0005
-epochs: 350
+epochs: 200
 save_on_epochs: [10, 50, 100]
-optimizer: Adam
+optimizer: SGD
 log_interval: 100
 
 pretrained: True
 
 scheduler: True
-scheduler_milestones: [30, 60]
 
 poisoning_proportion: 1.1
 backdoor_label: 8
-backdoor: False
+backdoor: True
 
 loss_balance: MGDA
 mgda_normalize: loss+
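For context: with the optimizer switched from Adam to SGD at lr 0.1 and scheduler_milestones dropped, the schedule's length now comes from epochs alone. A minimal sketch of how these keys might be consumed (the dict accesses mirror the YAML key names; the nn.Linear model is a stand-in, not the repo's real model):

import yaml
from torch import nn, optim
from torch.optim.lr_scheduler import CosineAnnealingLR

# Load the config shown above; keys mirror the YAML names.
with open("configs/cifar10_params.yaml") as f:
    params = yaml.safe_load(f)

model = nn.Linear(10, 10)  # stand-in for the real CIFAR-10 model
optimizer = optim.SGD(model.parameters(),
                      lr=params["lr"],               # 0.1
                      momentum=params["momentum"],   # 0.9
                      weight_decay=params["decay"])  # 0.0005
# Cosine annealing over the full run, as make_scheduler() does below.
scheduler = (CosineAnnealingLR(optimizer, T_max=params["epochs"])
             if params["scheduler"] else None)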
tasks/task.py (3 additions & 6 deletions)
@@ -5,7 +5,7 @@
 from torch import optim, nn
 from torch.nn import Module
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import MultiStepLR
+from torch.optim.lr_scheduler import CosineAnnealingLR
 from torchvision.transforms import transforms
 
 from metrics.accuracy_metric import AccuracyMetric
@@ -29,7 +29,7 @@ class Task:
     model: Module = None
     optimizer: optim.Optimizer = None
     criterion: Module = None
-    scheduler: MultiStepLR = None
+    scheduler: CosineAnnealingLR = None
     metrics: List[Metric] = None
 
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -85,10 +85,7 @@ def make_optimizer(self, model=None) -> Optimizer:
 
     def make_scheduler(self) -> None:
         if self.params.scheduler:
-            self.scheduler = MultiStepLR(self.optimizer,
-                                         milestones=self.params.scheduler_milestones,
-                                         last_epoch=self.params.start_epoch,
-                                         gamma=0.1)
+            self.scheduler = CosineAnnealingLR(self.optimizer, T_max=self.params.epochs)
 
     def resume_model(self):
         if self.params.resume_model:
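Unlike MultiStepLR, which multiplies the learning rate by gamma=0.1 at each milestone, CosineAnnealingLR decays it smoothly as eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2. A small sketch of that closed form, plugging in the lr=0.1 and epochs=200 from the config above:

import math

def cosine_lr(epoch, base_lr=0.1, t_max=200, eta_min=0.0):
    # Closed form of the schedule CosineAnnealingLR follows.
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

for epoch in (0, 50, 100, 150, 200):
    print(epoch, round(cosine_lr(epoch), 4))
# 0 -> 0.1, 50 -> 0.0854, 100 -> 0.05, 150 -> 0.0146, 200 -> 0.0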
training.py (2 additions & 1 deletion)
@@ -64,7 +64,8 @@ def run(hlpr):
         acc = test(hlpr, epoch, backdoor=False)
         test(hlpr, epoch, backdoor=True)
         hlpr.save_model(hlpr.task.model, epoch, acc)
-
+        if hlpr.task.scheduler is not None:
+            hlpr.task.scheduler.step(epoch)
 
 def fl_run(hlpr: Helper):
     for epoch in range(hlpr.params.start_epoch,
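The scheduler is stepped once per epoch, after testing and saving. Note that the explicit epoch argument to step() works but is deprecated in PyTorch 1.4+, where a bare step() per epoch is preferred. A self-contained sketch of the same per-epoch stepping pattern with a toy model:

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import CosineAnnealingLR

model = nn.Linear(4, 2)  # toy stand-in for the task model
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = CosineAnnealingLR(optimizer, T_max=200)

for epoch in range(200):
    loss = model(torch.randn(8, 4)).pow(2).mean()  # stand-in train step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # once per epoch, after the optimizer step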