-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_tf.py
68 lines (54 loc) · 1.7 KB
/
main_tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# %%
# # import torch
import numpy as npcon
import matplotlib.pyplot as plt
import tensorflow as tf
# tf.config.experimental.set_per_process_memory_fraction(0.75)
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(enable = True, device = physical_devices[0])
# config = tf.ConfigProto()
# config.gpu_options.allow_growth = True
# sess = tf.Session(config=config)
import FSHA
from datasets import *
xpriv, xpub = load_mnist()
batch_size = 64
id_setup = 4
hparams = {
'WGAN' : True,
'gradient_penalty' : 500., # orginal is 500
'style_loss' : None,
'lr_f' : 0.00001,
'lr_tilde' : 0.00001,
'lr_D' : 0.0001,
}
fsha = FSHA.FSHA(xpriv, xpub, id_setup-1, batch_size, hparams)
log_frequency = 500
LOG = fsha(10000, verbose=True, progress_bar=True, log_frequency=log_frequency)
def plot_log(ax, x, y, label):
ax.plot(x, y, color='black')
ax.set(title=label)
ax.grid()
n = 4
fix, ax = plt.subplots(1, n, figsize=(n*5, 3))
x = np.arange(0, len(LOG)) * log_frequency
plot_log(ax[0], x, LOG[:, 0], label='Loss $f$')
plot_log(ax[1], x, LOG[:, 1], label='Loss $\\tilde{f}$ and $\\tilde{f}^{-1}$')
plot_log(ax[2], x, LOG[:, 2], label='Loss $D$')
plot_log(ax[3], x, LOG[:, 3], label='Reconstruction error (VALIDATION)')
n = 20
X = getImagesDS(xpriv, n)
X_recovered, control = fsha.attack(X)
def plot(X):
n = len(X)
X = (X+1)/2
fig, ax = plt.subplots(1, n, figsize=(n*3,3))
plt.axis('off')
plt.subplots_adjust(wspace=0, hspace=-.05)
for i in range(n):
ax[i].imshow((X[i]), cmap='inferno');
ax[i].set(xticks=[], yticks=[])
ax[i].set_aspect('equal')
return fig
fig = plot(X)
fig = plot(X_recovered)