Skip to content

Commit

Permalink
Fix issues with generate
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacsquires committed Feb 7, 2022
1 parent 2f0a713 commit defaeda
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
1 change: 1 addition & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self, tag):
self.Lambda = 10
self.critic_iters = 10
self.lz = 4
self.lf = 4
self.ngpu = 1
if self.ngpu > 0:
self.device_name = "cuda:0"
Expand Down
6 changes: 3 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@ def main(mode, offline, tag):
:raises ValueError: [description]
"""
print("Running in {} mode, tagged {}, offline {}".format(mode, tag, offline))
overwrite = util.check_existence(tag)
util.initialise_folders(tag, overwrite)

# Initialise Config object
c = Config(tag)

if mode == 'train':
overwrite = util.check_existence(tag)
util.initialise_folders(tag, overwrite)
netD, netG = networks.make_nets(c, overwrite)
train(c, netG, netD, offline=offline, overwrite=overwrite)

elif mode == 'generate':
netD, netG = networks.make_nets(c, Training=0)
netD, netG = networks.make_nets(c, training=0)
net_g = netG()
util.generate(c, net_g)
print("Img generated")
Expand Down
2 changes: 1 addition & 1 deletion src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def generate(c, netG):
netG = nn.DataParallel(netG, list(range(ngpu))).to(device)
netG.load_state_dict(torch.load(f"{pth}/Gen.pt"))
netG.eval()
noise = torch.randn(1, nz, lf, lf, lf)
noise = torch.randn(1, nz, lf, lf)
raw = netG(noise)
gb = post_process(raw)
tif = np.array(gb[0], dtype=np.uint8)
Expand Down

0 comments on commit defaeda

Please sign in to comment.