Skip to content

Commit

Permalink
Load and continue training existing model
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacsquires committed Feb 7, 2022
1 parent 26cc255 commit 4a1fa0f
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def main(mode, offline, tag):

if mode == 'train':
netD, netG = networks.make_nets(c, overwrite)
train(c, netG, netD, offline=offline)
train(c, netG, netD, offline=offline, overwrite=overwrite)

elif mode == 'generate':
netD, netG = networks.make_nets(c, Training=0)
Expand Down
6 changes: 5 additions & 1 deletion src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import tifffile
import time

def train(c, Gen, Disc, offline=True):
def train(c, Gen, Disc, offline=True, overwrite=True):
"""[summary]
:param c: [description]
Expand Down Expand Up @@ -45,6 +45,10 @@ def train(c, Gen, Disc, offline=True):
optD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, beta2))
optG = optim.Adam(netG.parameters(), lr=lrg, betas=(beta1, beta2))

if not overwrite:
netG.load_state_dict(torch.load(f"{path}/Gen.pt"))
netD.load_state_dict(torch.load(f"{path}/Disc.pt"))

wandb_init(tag, offline)
wandb.watch(netD, log='all', log_freq=100)
wandb.watch(netG, log='all', log_freq=100)
Expand Down
3 changes: 2 additions & 1 deletion src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def check_existence(tag):
raise SystemExit
else:
raise AssertionError("Incorrect argument entered.")
return True


# set-up util
Expand Down Expand Up @@ -217,7 +218,7 @@ def generate(c, netG):
torch.cuda.is_available() and ngpu > 0) else "cpu")
if (ngpu > 1):
netG = nn.DataParallel(netG, list(range(ngpu))).to(device)
netG.load_state_dict(torch.load(f"{pth}/nets/Gen.pt"))
netG.load_state_dict(torch.load(f"{pth}/Gen.pt"))
netG.eval()
noise = torch.randn(1, nz, lf, lf, lf)
raw = netG(noise)
Expand Down

0 comments on commit 4a1fa0f

Please sign in to comment.