# utils_my_original.py
import torch
import torch.nn as nn

# Default device for callers of these utilities.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SeqToANNContainer(nn.Module):
    # This code is from spikingjelly: https://github.com/fangwei123456/spikingjelly
    # Wraps a stateless ANN module so it can process time-stepped input of
    # shape [B, T, ...] by folding the batch and time dimensions together.
    def __init__(self, *args):
        super().__init__()
        if len(args) == 1:
            self.module = args[0]
        else:
            self.module = nn.Sequential(*args)

    def forward(self, x_seq: torch.Tensor):
        # Merge [B, T, ...] into [B*T, ...], run the wrapped module,
        # then restore the leading batch and time dimensions.
        y_shape = [x_seq.shape[0], x_seq.shape[1]]
        y_seq = self.module(x_seq.flatten(0, 1).contiguous())
        y_shape.extend(y_seq.shape[1:])
        return y_seq.view(y_shape)
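
# Usage sketch (illustrative, not from the original file): wrapping a Conv2d
# so it runs over a [B, T, C, H, W] tensor via one merged batch.
if __name__ == "__main__":
    conv_td = SeqToANNContainer(nn.Conv2d(3, 16, kernel_size=3, padding=1))
    x_demo = torch.randn(4, 8, 3, 32, 32)  # batch=4, time steps=8
    print(conv_td(x_demo).shape)  # torch.Size([4, 8, 16, 32, 32])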
class tdLayer(nn.Module):
    # Time-distributed wrapper: applies `layer` independently at every time step.
    def __init__(self, layer):
        super(tdLayer, self).__init__()
        self.layer = SeqToANNContainer(layer)

    def forward(self, x):
        return self.layer(x)
class tdBatchNorm(nn.Module):
    # Batch normalization computed over the merged batch/time dimension.
    def __init__(self, out_panel):
        super(tdBatchNorm, self).__init__()
        self.bn = nn.BatchNorm2d(out_panel)
        self.seqbn = SeqToANNContainer(self.bn)

    def forward(self, x):
        return self.seqbn(x)
def replace_layer_by_tdlayer(model):
    # Recursively wrap every stateless layer in a tdLayer so the whole model
    # accepts inputs with an extra time dimension ([B, T, C, H, W]).
    for name, module in model._modules.items():
        if hasattr(module, "_modules"):
            model._modules[name] = replace_layer_by_tdlayer(module)
        if module.__class__.__name__ in ('Conv2d', 'Linear', 'BatchNorm2d',
                                         'AvgPool2d', 'Dropout', 'AdaptiveAvgPool2d'):
            model._modules[name] = tdLayer(model._modules[name])
        if module.__class__.__name__ == 'Flatten':
            # Flatten only the channel/spatial dims, keeping [B, T] intact.
            model._modules[name] = nn.Flatten(start_dim=-3, end_dim=-1)
    return model
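
# Conversion sketch (the small ANN below is hypothetical): turning a plain
# network into a time-distributed one whose forward expects [B, T, C, H, W].
if __name__ == "__main__":
    ann = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.BatchNorm2d(8),
        nn.AvgPool2d(2),
        nn.Flatten(),
        nn.Linear(8 * 16 * 16, 10),
    )
    td_model = replace_layer_by_tdlayer(ann)
    x_demo = torch.randn(2, 4, 3, 32, 32)  # batch=2, time steps=4
    print(td_model(x_demo).shape)  # torch.Size([2, 4, 10])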
def isActivation(name):
    # True for spiking-activation layers, identified by class name.
    return 'spike_layer' in name.lower()
def replace_maxpool2d_by_avgpool2d(model):
    # Recursively swap every MaxPool2d for an AvgPool2d with the same geometry.
    for name, module in model._modules.items():
        if hasattr(module, "_modules"):
            model._modules[name] = replace_maxpool2d_by_avgpool2d(module)
        if module.__class__.__name__ == 'MaxPool2d':
            model._modules[name] = nn.AvgPool2d(kernel_size=module.kernel_size,
                                                stride=module.stride,
                                                padding=module.padding)
    return model
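
# Quick sketch (hypothetical model): max pooling has no direct rate-coded
# spiking equivalent, so it is swapped for average pooling of the same geometry.
if __name__ == "__main__":
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.MaxPool2d(2))
    net = replace_maxpool2d_by_avgpool2d(net)
    print(net[1])  # AvgPool2d(kernel_size=2, stride=2, padding=0)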
def add_dimension(x, T):
    # Insert a time dimension and repeat the input T times:
    # [B, C, H, W] -> [B, T, C, H, W]. Note: unsqueeze_ mutates x in place.
    x.unsqueeze_(1)
    x = x.repeat(1, T, 1, 1, 1)
    return x
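
# Shape sketch: add_dimension repeats a static image T times along a new time
# axis, e.g. for constant-input encoding over T time steps.
if __name__ == "__main__":
    img = torch.randn(2, 3, 32, 32)
    seq = add_dimension(img, T=4)
    print(seq.shape)  # torch.Size([2, 4, 3, 32, 32])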
def isActivation_spike(name):
    # Same check as isActivation above.
    return 'spike_layer' in name.lower()