Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/katiezzzzz/Gantastic
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeng committed Sep 13, 2021
2 parents 2e8f6f5 + 484b690 commit 6b8defe
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 22 deletions.
6 changes: 0 additions & 6 deletions cgan_earth/code_files/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,6 @@ def forward(self, noise, labels, Training=True, ratio=2):
for layer in self.gen:
x = layer(x)
# upsample to give output spatial size (img_length, img_length)
'''
if Training:
up = F.interpolate(x, size = (self.img_length+2, self.img_length+2))
else:
up = F.interpolate(x, size = (x.shape[-2], x.shape[-1]))
'''
return torch.sigmoid(self.final_conv(x))


Expand Down
16 changes: 15 additions & 1 deletion cgan_earth/code_files/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,19 @@ def circular_transit(label1_channel, label2_channel, cur_label, z_step_size, l_s
new_label, l_step, z_step, l_done_step, z_done_step = circular_transit(0, 1, new_label, 0.3, 0.5, 4, 2, l_step, z_step,
l_done_step, z_done_step)

print(new_label)

a = torch.randn((1, 3, 6, 13))
print(a)
#a[:, :, torch.randint(a.shape[-2], (int(a.shape[-2]/2),))][torch.randint(a.shape[-1], (int(a.shape[-1]/2),))] = 0
b = torch.zeros_like(a)
for idx0 in range(a.shape[0]):
for idx1 in range(a.shape[1]):
old = a[idx0][idx1].clone().flatten()
old[torch.randint(len(old), (int(len(old)/2),))] = 0
b[idx0][idx1] = old.reshape(a.shape[-2], a.shape[-1])
boolean = torch.eq(a, b)
print(b)
print(b[boolean])
print(b[~boolean])
b[boolean] = 0
print(b)
64 changes: 57 additions & 7 deletions cgan_earth/code_files/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def roll_video(path, label, netG, n_classes, z_dim=64, lf=4, device='cpu', ratio
original_noise = add_noise_dim(random, original_noise, 3)
else:
max_len = original_noise.shape[-1]

netG.eval()
test_label = gen_labels(label, n_classes)[:, :, None, None]
imgs = np.array([])
Expand Down Expand Up @@ -335,12 +335,9 @@ def transit_video(label1, label2, n_classes, original_noise, netG, lf=4, ratio=2
img = torch.multiply(img, 255).cpu().detach().numpy()
for i in range(num_img):
if step_size < 1:
if i == 0:
# one z represents 32 pixels in the -1 dimension
out = img[:, :, :, :img.shape[-1]-32]
else:
# currently only implemented for step_size 0.5
out = img[:, :, :, 16:img.shape[-1]-16]
# one z represents 32 pixels in the -1 dimension
step_idx = int(i * step_size * 32)
out = img[:, :, :, step_idx:img.shape[-1]-(32-step_idx)]
else:
out = img[:, :, :, :img.shape[-1]-32]
out = np.moveaxis(out, 1, -1)
Expand Down Expand Up @@ -368,6 +365,59 @@ def transit_video(label1, label2, n_classes, original_noise, netG, lf=4, ratio=2
noise = roll_noise(original_noise, step, max_step, IntStep)
return imgs, noise, netG

def change_noise(label, original_noise, netG, n_classes, z_dim=64, lf=4, device='cpu', ratio=2, n_clips=30, step_size=1, value=0.01, method='add'):
max_len = original_noise.shape[-1]

test_label = gen_labels(label, n_classes)[:, :, None, None]

lbl = test_label.repeat(1, 1, lf, max_len).to(device)
imgs = np.array([])
if method == 'combined':
noise, bool_tensor = vary_noise(original_noise, value/3, ratio=0.5)
else:
noise = original_noise
step = 0.0
if step_size >= 1:
num_img = 1
else:
num_img = int(1/step_size)
for _ in tqdm(range(n_clips)):
with torch.no_grad():
img = netG(noise, lbl, Training=False, ratio=ratio).cuda()
img = torch.multiply(img, 255).cpu().detach().numpy()
for i in range(num_img):
if step_size < 1:
# one z represents 32 pixels in the -1 dimension
step_idx = int(i * step_size * 32)
out = img[:, :, :, step_idx:img.shape[-1]-(32-step_idx)]
else:
out = img[:, :, :, :img.shape[-1]-32]
out = np.moveaxis(out, 1, -1)
if imgs.shape[0] == 0:
imgs = out
else:
imgs = np.vstack((imgs, out))
step += step_size
max_step = lf*ratio-2
if max_len == lf*ratio:
IntStep = True
else:
IntStep = False
if step_size < 1:
step_idx = 1
else:
step_idx = step_size
if step > max_step:
step -= max_step
noise = roll_noise(noise, step_idx, max_step, IntStep)
if method == 'add':
noise = torch.sub(noise, value)
elif method == 'sub':
noise = torch.add(noise, value)
elif method == 'combined':
noise, bool_tensor = vary_noise(noise, value/2, ratio=0.5, boolean=bool_tensor)
return imgs, noise, netG

def animate(path, imgs, fps=24):
clip = ImageSequenceClip(list(imgs),fps=fps)
clip.write_gif(path + '_demo1.gif')
Expand Down
23 changes: 23 additions & 0 deletions cgan_earth/code_files/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,29 @@ def roll_noise(original_noise, step, max_step, IntStep=True):
out_noise = original_noise
return out_noise

def replace_noise(original_noise, z_dim, lf, ratio, device):
new_noise = torch.zeros_like(original_noise)
# keep z0, z1, ...
new_noise[:, :, :, :-1] = original_noise[:, :, :, :-1]
# slot in new random noise
new_noise[:, :, :, -1] = torch.randn(1, z_dim, lf, device=device)
return new_noise

def vary_noise(original_noise, value, ratio, boolean=None):
new_noise = torch.zeros_like(original_noise)
if boolean == None:
for idx0 in range(original_noise.shape[0]):
for idx1 in range(original_noise.shape[1]):
old = original_noise[idx0][idx1].clone().flatten()
old[torch.randint(len(old), (int(len(old)*ratio),))] = torch.add(old[torch.randint(len(old), (int(len(old)*ratio),))], value)
new_noise[idx0][idx1] = old.reshape(original_noise.shape[-2], original_noise.shape[-1])
boolean = torch.eq(original_noise, new_noise)
new_noise[boolean] = torch.sub(new_noise[boolean], value)
else:
new_noise[~boolean] = torch.add(original_noise[~boolean], value)
new_noise[boolean] = torch.sub(original_noise[boolean], value)
return new_noise, boolean

def uniform_transit(label1_channel, label2_channel, cur_label, l_step_size):
'''
Compute uniform transition form one label to the other
Expand Down
10 changes: 6 additions & 4 deletions cgan_earth/make_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,17 @@
sea_lbl = [3]
snow_lbl = [4]
lf = 16
ratio = 2
ratio = 4

# test1: forest, then transit to sea, then roll in sea
# the speed currently must start with the lowest possible speed to make sure the noise has right dimension
imgs1, noise, netG = roll_video(proj_path, snow_lbl, netG(z_dim+n_classes, img_length), n_classes, z_dim, lf=lf, device=device, ratio=ratio, n_clips=24*3, step_size=0.5)
imgs2, noise, netG = transit_video(snow_lbl, city_lbl, n_classes, noise, netG, lf=lf, ratio=ratio, device=device, step_size=0.25, z_step_size=0.1, l_step_size=0.1, transit_mode='uniform')
imgs3, noise, netG = roll_video(proj_path, city_lbl, netG, n_classes, z_dim, lf=lf, device=device, ratio=ratio, n_clips=24*3, step_size=0.5, original_noise=noise)
imgs1, noise, netG = roll_video(proj_path, sea_lbl, netG(z_dim+n_classes, img_length), n_classes, z_dim, lf=lf, device=device, ratio=ratio, n_clips=5, step_size=0.25)
imgs2, noise, netG = transit_video(sea_lbl, forest_lbl, n_classes, noise, netG, lf=lf, ratio=ratio, device=device, step_size=0.25, z_step_size=0.1, l_step_size=0.2, transit_mode='uniform')
imgs3, noise, netG = roll_video(proj_path, forest_lbl, netG, n_classes, z_dim, lf=lf, device=device, ratio=ratio, n_clips=15, step_size=0.25, original_noise=noise)
imgs4, noise, netG = change_noise(forest_lbl, noise, netG, n_classes, z_dim, lf=lf, device=device, ratio=ratio, n_clips=30, step_size=0.25, value=0.00001, method='combined')

# concatenante the imgs together and make video
imgs = np.vstack((imgs1, imgs2))
imgs = np.vstack((imgs, imgs3))
imgs = np.vstack((imgs, imgs4))
animate(proj_path, imgs, fps=24)
8 changes: 4 additions & 4 deletions cgan_earth/run_cgan_earth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import os

PATH = os.path.dirname(os.path.realpath(__file__))
Project_name = 'earth_cylinder'
Project_name = 'earth_cylinder_t'
Project_dir = PATH + '/trained_generators/'
wandb_name = Project_name

# import training images and define labels
data_path = []
labels = []

for img_path, label in zip(['forest1', 'city1', 'desert1', 'sea1', 'snow1', 'star1'], [0, 1, 2, 3, 4, 5]):
for img_path, label in zip(['forest1'], [0]):
file = PATH + '/earth_screenshots/{}.jpg'.format(img_path)
data_path.append(file) # path to training data
labels.append(label)
Expand All @@ -23,8 +23,8 @@
z_dim = 64
lr = 0.0001
Training = 1
n_classes = 6
batch_size = 12
n_classes = 1
batch_size = 2
im_channels = 3
num_epochs = 600
img_length = 128 # size of training image
Expand Down

0 comments on commit 6b8defe

Please sign in to comment.