-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsimple_demo_all_public.py
177 lines (122 loc) · 4.59 KB
/
simple_demo_all_public.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# When running inside Google Colab, install ezkl and onnx into the notebook's
# environment; everywhere else, rely on a local installation of both packages.
try:
    import google.colab  # noqa: F401 -- imported only to detect Colab
except ImportError:
    # Not in Colab: ezkl and onnx must already be installed locally.
    pass
else:
    import subprocess
    import sys

    # Use the running interpreter's pip so the packages are installed into
    # the same environment this code executes in. A failed install now
    # raises instead of being silently swallowed by a bare `except:`.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "ezkl"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "onnx"])
# Here we create (and potentially train) a model.
# Make sure the dependencies required below are already installed.
from torch import nn
import ezkl
import os
import json
import torch
# Defines the model
# we got convs, we got relu, we got linear layers
# What else could one want ????
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=5, stride=2)
self.conv2 = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=5, stride=2)
self.relu = nn.ReLU()
self.d1 = nn.Linear(48, 48)
self.d2 = nn.Linear(48, 10)
def forward(self, x):
# 32x1x28x28 => 32x32x26x26
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.relu(x)
# flatten => 32 x (32*26*26)
x = x.flatten(start_dim = 1)
# 32 x (32*26*26) => 32x128
x = self.d1(x)
x = self.relu(x)
# logits => 32x10
logits = self.d2(x)
return logits
# Instantiate the model. Train it as you like here (skipped for brevity).
circuit = MyModel()

# Artifact locations, relative to the current working directory.
model_path = os.path.join('network.onnx')
compiled_model_path = os.path.join('network.compiled')
pk_path = os.path.join('test.pk')
vk_path = os.path.join('test.vk')
settings_path = os.path.join('settings.json')
witness_path = os.path.join('witness.json')
data_path = os.path.join('input.json')

# Per-sample input shape (channels, height, width); channels=1 matches conv1.
shape = [1, 28, 28]

# After training, export to ONNX (network.onnx) and create a data file
# (input.json). The sample input has values in [0, 0.1); gradients are not
# needed for export, so requires_grad is not set.
x = 0.1 * torch.rand(1, *shape)

# Flip the network into inference mode before exporting.
circuit.eval()

# Export the model
torch.onnx.export(circuit,               # model being run
                  x,                     # model input (or a tuple for multiple inputs)
                  model_path,            # where to save the model (can be a file or file-like object)
                  export_params=True,    # store the trained parameter weights inside the model file
                  opset_version=10,      # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                'output': {0: 'batch_size'}})

# Flatten the sample into a plain list of floats: {"input_data": [[...]]}.
data_array = x.detach().numpy().reshape([-1]).tolist()
data = dict(input_data=[data_array])

# Serialize data into file; the context manager guarantees the handle is
# flushed and closed before later steps read input.json.
with open(data_path, 'w') as data_file:
    json.dump(data, data_file)
py_run_args = ezkl.PyRunArgs()
py_run_args.input_visibility = "public"
py_run_args.output_visibility = "public"
py_run_args.param_visibility = "fixed" # "fixed" for params means that the committed to params are used for all proofs
res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)
assert res == True
cal_path = os.path.join("calibration.json")
data_array = (torch.rand(20, *shape, requires_grad=True).detach().numpy()).reshape([-1]).tolist()
data = dict(input_data = [data_array])
# Serialize data into file:
json.dump(data, open(cal_path, 'w'))
await ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")
res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
assert res == True
# srs path
res = ezkl.get_srs( settings_path)
# now generate the witness file
res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)
# HERE WE SETUP THE CIRCUIT PARAMS
# WE GOT KEYS
# WE GOT CIRCUIT PARAMETERS
# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK
res = ezkl.setup(
compiled_model_path,
vk_path,
pk_path,
)
assert res == True
assert os.path.isfile(vk_path)
assert os.path.isfile(pk_path)
assert os.path.isfile(settings_path)
# GENERATE A PROOF
proof_path = os.path.join('test.pf')
res = ezkl.prove(
witness_path,
compiled_model_path,
pk_path,
proof_path,
"single",
)
print(res)
assert os.path.isfile(proof_path)
# VERIFY IT
res = ezkl.verify(
proof_path,
settings_path,
vk_path,
)
assert res == True
print("verified")