From 4a1fa0fc9aabdf65ce9c4c52d4401ae40c043126 Mon Sep 17 00:00:00 2001 From: isaacsquires Date: Mon, 7 Feb 2022 19:51:19 +0000 Subject: [PATCH] Load and continue training existing model --- main.py | 2 +- src/train.py | 6 +++++- src/util.py | 3 ++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 0eb9ced..68317b1 100644 --- a/main.py +++ b/main.py @@ -27,7 +27,7 @@ def main(mode, offline, tag): if mode == 'train': netD, netG = networks.make_nets(c, overwrite) - train(c, netG, netD, offline=offline) + train(c, netG, netD, offline=offline, overwrite=overwrite) elif mode == 'generate': netD, netG = networks.make_nets(c, Training=0) diff --git a/src/train.py b/src/train.py index df15e57..bdb1cdc 100644 --- a/src/train.py +++ b/src/train.py @@ -7,7 +7,7 @@ import tifffile import time -def train(c, Gen, Disc, offline=True): +def train(c, Gen, Disc, offline=True, overwrite=True): """[summary] :param c: [description] @@ -45,6 +45,10 @@ def train(c, Gen, Disc, offline=True): optD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, beta2)) optG = optim.Adam(netG.parameters(), lr=lrg, betas=(beta1, beta2)) + if not overwrite: + netG.load_state_dict(torch.load(f"{path}/Gen.pt")) + netD.load_state_dict(torch.load(f"{path}/Disc.pt")) + wandb_init(tag, offline) wandb.watch(netD, log='all', log_freq=100) wandb.watch(netG, log='all', log_freq=100) diff --git a/src/util.py b/src/util.py index ddc7baf..8bc2b9e 100644 --- a/src/util.py +++ b/src/util.py @@ -37,6 +37,7 @@ def check_existence(tag): raise SystemExit else: raise AssertionError("Incorrect argument entered.") + return True # set-up util @@ -217,7 +218,7 @@ def generate(c, netG): torch.cuda.is_available() and ngpu > 0) else "cpu") if (ngpu > 1): netG = nn.DataParallel(netG, list(range(ngpu))).to(device) - netG.load_state_dict(torch.load(f"{pth}/nets/Gen.pt")) + netG.load_state_dict(torch.load(f"{pth}/Gen.pt")) netG.eval() noise = torch.randn(1, nz, lf, lf, lf) raw = netG(noise)