-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdeepLeakageFromGradients.py
211 lines (166 loc) · 6.68 KB
/
deepLeakageFromGradients.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
# -*- coding: utf-8 -*-
"""Deep Leakage from Gradients.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/gist/Lyken17/91b81526a8245a028d4f85ccc9191884/deep-leakage-from-gradients.ipynb
## Deep Leakage from Gradients
*Modified by Joshua Towell (2128037) on 04/04/22*
**Original code reference:**\
Zhu, L. (2019) Deep Leakage From Gradients [software]. MIT HAN Lab.
Available from https://github.com/mit-han-lab/dlg [accessed 4 April 2022].
### Modifications:
* Changes to configuration of DLG algorithm and datasets used
* Fully commented and explained code
"""
# Commented out IPython magic to ensure Python compatibility.
# Import necessary libraries and external modules
# PyTorch and Torch vision contain pre-defined functions and datasets relevant to ML
# %matplotlib inline
import numpy as np
from pprint import pprint
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
import torchvision
from torchvision import models, datasets, transforms
# Seed PyTorch's random number generator
torch.manual_seed(50)

# Report the library versions in use
# (PyTorch 1.10.0+cu111 and TorchVision 0.11.1+cu111 were used originally)
print(torch.__version__, torchvision.__version__)

# Load the predefined dataset (CIFAR100 and CIFAR10 were used), downloading it if required
dst = datasets.CIFAR100("~/.torch", download=True)

# Pre-processing pipeline applied to a dataset image before it enters the model:
# resize, centre-crop to 32x32, then convert the PIL image to a tensor
tp = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
    ]
)

# Inverse helper: turns a tensor back into a PIL image for display
tt = transforms.ToPILImage()

# Run on the CUDA GPU when one is available, otherwise on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on %s" % device)
# Function for converting integer class labels into one-hot encoded vectors
def label_to_onehot(target, num_classes=100):
    """Return a one-hot encoding of a batch of class indices.

    Args:
        target: 1-D integer tensor of class indices; it is used as the
            index argument of ``scatter_``.
        num_classes: Width of each one-hot row.

    Returns:
        A ``(len(target), num_classes)`` tensor on the same device as
        ``target`` holding 1 at each row's class index and 0 elsewhere.
    """
    batch_size = target.size(0)
    # scatter_ needs the index as a column: one class index per output row
    column_index = target.unsqueeze(1)
    encoded = torch.zeros(batch_size, num_classes, device=target.device)
    return encoded.scatter_(1, column_index, 1)
# Cross-entropy loss that takes a one-hot (or soft) target distribution instead of class indices
def cross_entropy_for_onehot(pred, target):
    """Cross-entropy between raw logits and a per-class target distribution.

    Args:
        pred: Logits; log-softmax is applied over the last dimension.
        target: Tensor of per-class target weights, multiplied element-wise
            with the log-probabilities (e.g. a one-hot row per sample).

    Returns:
        Scalar tensor: the per-sample losses (summed over dimension 1)
        averaged over all samples.
    """
    log_probs = F.log_softmax(pred, dim=-1)
    per_sample_loss = (-target * log_probs).sum(1)
    return per_sample_loss.mean()
# Preparing the neural network's parameters
# Initialises a module's weights and biases uniformly in the range [-0.5, 0.5]
def weights_init(m):
    """Re-initialise a module's weight and bias uniformly in [-0.5, 0.5].

    Written to be passed to ``nn.Module.apply``, which calls it on every
    sub-module. Modules without a ``weight``/``bias`` attribute are left
    untouched, and so are modules whose attribute exists but is ``None``
    (for example a layer created with ``bias=False``), which previously
    raised ``AttributeError``.

    Args:
        m: The module to initialise in place.
    """
    # Weight first, then bias, so the order of random draws is unchanged
    for attr_name in ("weight", "bias"):
        param = getattr(m, attr_name, None)
        if param is not None:
            param.data.uniform_(-0.5, 0.5)
# Defining LeNet neural network used for training
# Selected Sigmoid activation function
# Defining convolution layers in specified sequence
class LeNet(nn.Module):
    """LeNet-style convolutional network with Sigmoid activations.

    The body is four 5x5 convolutions with 12 output channels each (the
    first two use stride 2, the last two stride 1), every one followed by a
    Sigmoid. The flattened features feed a single linear classifier.

    Args:
        num_classes: Number of output logits. Defaults to 100 (CIFAR100);
            pass 10 for CIFAR10.
        in_channels: Number of channels in the input image. Defaults to 3.
    """

    def __init__(self, num_classes=100, in_channels=3):
        super().__init__()
        act = nn.Sigmoid
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            act(),
        )
        # 768 = 12 channels x 8 x 8: a 32x32 input is halved by each of the
        # two stride-2 convolutions, so other input sizes will not fit
        self.fc = nn.Sequential(
            nn.Linear(768, num_classes)
        )

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        out = self.body(x)
        # Flatten everything except the batch dimension for the linear layer
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
# Build the LeNet model on the selected device, re-initialise its parameters
# and choose the loss function
net = LeNet().to(device)
net.apply(weights_init)
criterion = cross_entropy_for_onehot

# Choose the participant (image) whose data will be attacked
# Images 45, 48 and 51 were used in both datasets
img_index = 45
sample_image, sample_label = dst[img_index]

# Pre-process the image and add a leading batch dimension of size 1
gt_data = tp(sample_image).to(device).unsqueeze(0)

# Ground-truth label as a length-1 integer tensor, plus its one-hot encoding
gt_label = torch.tensor([sample_label], dtype=torch.long).to(device)
gt_onehot_label = label_to_onehot(gt_label, num_classes=100)

# Display ground truth training data and label
plt.imshow(tt(gt_data[0].cpu()))
plt.title("Ground truth image")
onehot_index = gt_onehot_label.argmax(dim=-1).item()
print("GT label is %d." % gt_label.item(), "\nOnehot label is %d." % onehot_index)

# Compute the original gradient of the loss with respect to the model parameters
out = net(gt_data)
y = criterion(out, gt_onehot_label)
dy_dx = torch.autograd.grad(y, net.parameters())

# This detached copy is the gradient that would be shared with other training nodes
original_dy_dx = [g.detach().clone() for g in dy_dx]
# Creating an initial random "dummy" image and label
# Both are drawn from a standard normal distribution with the same shapes as
# the ground truth, and have gradients enabled so the optimiser can update them
dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
dummy_label = torch.randn(gt_onehot_label.size()).to(device).requires_grad_(True)

# Display dummy data and label
# A new figure is opened first: when run as a script, imshow would otherwise
# draw onto whichever axes are current and overwrite the earlier image
plt.figure()
plt.imshow(tt(dummy_data[0].detach().cpu()))
plt.title("Dummy data")
print("Dummy label is %d." % torch.argmax(dummy_label, dim=-1).item())
# Select optimiser and begin iterative attack on shared gradients to re-construct original image
# The LBFGS optimisation algorithm updates the dummy data and dummy label directly
optimizer = torch.optim.LBFGS([dummy_data, dummy_label])

# Store iteration numbers and the matching loss values (as plain floats) for plotting
print("Iter. Loss")
ax_iter = []
ax_loss = []

# Store the reconstructed image after every iteration
history = []


def closure():
    """Evaluate the gradient-matching loss for the current dummy data and label.

    Zeroes the optimiser's gradients, computes the model gradients produced
    by the dummy data/label, and back-propagates the summed squared
    difference between those gradients and the original shared gradients.

    Returns:
        The gradient distance as a scalar tensor.
    """
    optimizer.zero_grad()
    pred = net(dummy_data)
    dummy_onehot_label = F.softmax(dummy_label, dim=-1)
    dummy_loss = criterion(pred, dummy_onehot_label)
    # create_graph=True keeps these gradients differentiable so the distance
    # below can itself be back-propagated to the dummy data and label
    dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)
    grad_diff = sum(
        ((gx - gy) ** 2).sum() for gx, gy in zip(dummy_dy_dx, original_dy_dx)
    )
    grad_diff.backward()
    return grad_diff


# Run 200 optimisation steps; each step moves the dummy data and label so that
# their gradients get closer to the original shared gradients
for iters in range(200):
    optimizer.step(closure)
    # Display progress every 10 iterations
    if iters % 10 == 0:
        # Re-evaluate the closure to get the loss at the updated dummy values
        current_loss = closure()
        print(iters, " ", "%.4f" % current_loss.item(), "%")
        ax_iter.append(iters)
        # Keep a float rather than the tensor so no autograd graph is retained
        ax_loss.append(current_loss.item())
    # Recorded on every iteration: the plotting code indexes history by iteration number
    history.append(tt(dummy_data[0].cpu()))
# Display the data loss value per iteration on a graph
# A new figure is opened so the curve is not drawn on the axes of an earlier image
plt.figure()
# Entries may be floats or scalar tensors; convert them all to plain floats
loss_values = [v.item() if torch.is_tensor(v) else float(v) for v in ax_loss]
plt.plot(list(ax_iter), loss_values)
plt.xlabel("Iterations")
plt.ylabel("Data Loss")
plt.show()

# Display the development of iterative image re-construction
# Plot the generated image from every 10th iteration (at most 20 panels in a 2x10 grid)
plt.figure(figsize=(12, 3))
for i, snapshot in enumerate(history[::10][:20]):
    plt.subplot(2, 10, i + 1)
    plt.imshow(snapshot)
    plt.title("iter=%d" % (i * 10))
    plt.axis('off')

# Display dummy label
print("Dummy label is %d." % torch.argmax(dummy_label, dim=-1).item())

# Show the reconstruction grid; without this call it never appears when run as a script
plt.show()