-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmodel_inversion_demo.py
91 lines (72 loc) · 3.37 KB
/
model_inversion_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import sys
from pathlib import Path
import numpy as np
import torch
from torchvision import transforms, datasets
from torchvision.utils import save_image
import unsplit.attacks as unsplit
from unsplit.models import *
from unsplit.util import *
# Command-line interface: <dataset> <split_layer>.
# Fail with a usage message instead of a bare IndexError when arguments are missing.
if len(sys.argv) < 3:
    sys.exit(f'usage: {sys.argv[0]} <mnist|f_mnist|cifar> <split_layer>')

dataset = sys.argv[1]
split_layer = int(sys.argv[2])

# Supported datasets: name -> (torchvision dataset class, download directory, model class).
dataset_configs = {
    'mnist': (datasets.MNIST, 'data/mnist', MnistNet),
    'f_mnist': (datasets.FashionMNIST, 'data/f_mnist', MnistNet),
    'cifar': (datasets.CIFAR10, 'data/cifar', CifarNet),
}

# Reject unknown names up front; otherwise the script would only fail later
# with a NameError when `trainset` is first used.
if dataset not in dataset_configs:
    sys.exit(f"unknown dataset {dataset!r}; expected one of: {', '.join(dataset_configs)}")

# Create the results directory if it doesn't exist.
Path("results").mkdir(parents=True, exist_ok=True)

# Load the datasets and initialize the client, server, and clone models.
dataset_cls, data_root, model_cls = dataset_configs[dataset]
trainset = dataset_cls(data_root, download=True, train=True, transform=transforms.ToTensor())
testset = dataset_cls(data_root, download=True, train=False, transform=transforms.ToTensor())
client, server, clone = model_cls(), model_cls(), model_cls()

# Training batches hold 64 shuffled samples; the test loader uses the
# DataLoader default batch size.
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=64)
testloader = torch.utils.data.DataLoader(testset, shuffle=True)
# -- TRAIN MODELS --
# Jointly train the client and server models. Each batch is passed through
# `client` (called with end=split_layer) and its output is fed to `server`
# (called with start=split_layer+1). Both optimizers step on the same loss.
# NOTE(review): presumably `end`/`start` select which layers each model runs --
# confirm against unsplit.models.
print('Training models...')
client_opt = torch.optim.Adam(client.parameters(), lr=0.001, amsgrad=True)
server_opt = torch.optim.Adam(server.parameters(), lr=0.001, amsgrad=True)
criterion = torch.nn.CrossEntropyLoss()

epochs = 10
for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in trainloader:
        client_opt.zero_grad()
        server_opt.zero_grad()

        pred = server(client(images, end=split_layer), start=split_layer+1)
        loss = criterion(pred, labels)
        loss.backward()

        # Accumulate a plain Python float: summing the loss tensor itself would
        # keep autograd history attached to the running total for the whole epoch.
        running_loss += loss.item()

        server_opt.step()
        client_opt.step()

    # Per-epoch report: mean training loss over batches and the test score.
    print(f'Epoch: {epoch} Loss: {running_loss / len(trainloader)} Acc: {get_test_score(client, server, testset, split=split_layer)}')

print('Done.')
# -- MODEL INVERSION & STEALING --
# For each target image, compute the client's output at the split layer and
# hand it to the attack together with the untrained `clone` model. The
# returned image is compared against the real input with MSE and saved.
print('Starting model inversion & stealing attack...')

# Load one example per class (labels 0-9) from the test set.
inversion_targets = [
    get_examples_by_class(testset, c, count=1)
    for c in range(10)
]
targetloader = torch.utils.data.DataLoader(inversion_targets, shuffle=False)

mse = torch.nn.MSELoss()
results = []
losses = []

for idx, target in enumerate(targetloader):
    # Obtain the client output for this target.
    client_out = client(target, end=split_layer)

    # Perform the attack.
    result = unsplit.model_inversion_stealing(
        clone,
        split_layer,
        client_out,
        target.size(),
        main_iters=1000,
        input_iters=100,
        model_iters=100,
    )

    # CIFAR results are passed through normalize() before being scored and saved.
    if dataset == 'cifar':
        result = normalize(result)

    # Record the result and its error against the true input, then save it.
    loss = mse(result, target)
    results.append(result)
    losses.append(loss)
    save_image(result, f'results/{dataset}_{idx}.png')
    print(f'\tImage {idx} loss: {loss}')

# Summary: mean reconstruction error, and the test score obtained when the
# clone is evaluated together with the client.
average_mse = sum(losses) / len(losses)
print(f'Average MSE: {average_mse}')
clone_score = get_test_score(client, clone, testset, split=split_layer)
print(f'Clone test score: {clone_score}%')
print('Results saved to the results directory.')