forked from yagguc/deep_impression
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path audiovisual_stream.py
27 lines (21 loc) · 974 Bytes
/
audiovisual_stream.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import auditory_stream
import chainer
import visual_stream
### MODEL ###
class ResNet18(chainer.Chain):
    """Audiovisual model fusing an auditory and a visual ResNet18 stream.

    The auditory stream encodes ``x[0]`` in a single call. The visual stream
    encodes ``x[1]`` in chunks along axis 0, and its outputs are averaged over
    all rows of ``x[1]``. The two feature vectors are concatenated and mapped
    by a linear layer to 5 outputs, which are squashed into [0, 1].
    """

    def __init__(self):
        super(ResNet18, self).__init__(
            aud=auditory_stream.ResNet18(),
            vis=visual_stream.ResNet18(),
            # in_size 512 must equal the combined feature width of ``aud`` and
            # ``vis``: their outputs are concatenated (axis 1) before ``fc``.
            fc=chainer.links.Linear(512, 5, initialW=chainer.initializers.HeNormal())
        )

    def __call__(self, x, chunk_size=256):
        """Predict 5 scores in [0, 1] for one audiovisual sample.

        Args:
            x: Pair ``(audio, frames)`` of host arrays. ``audio`` goes to the
                auditory stream in one call. ``frames`` is sliced along axis 0
                and sent to the visual stream ``chunk_size`` rows at a time.
                The exact shapes and dtypes are defined by the stream modules,
                which are not visible here.
            chunk_size: Rows of ``frames`` transferred to the GPU per
                visual-stream call. This bounds peak GPU memory. Defaults to
                256, the previously hard-coded value.

        Returns:
            Host (NumPy) array holding the first row of
            ``(tanh(fc(concat(audio_feat, visual_feat))) + 1) / 2``.
            Every value lies in [0, 1].

        Raises:
            ValueError: if ``frames`` has no rows or ``chunk_size`` < 1.
        """
        audio, frames = x[0], x[1]
        if frames.shape[0] == 0:
            # Averaging would divide by zero, and the visual stream would get
            # an empty batch.
            raise ValueError('x[1] must contain at least one frame')
        if chunk_size < 1:
            raise ValueError('chunk_size must be a positive integer')
        # The result leaves as a raw host array, so no gradient can flow back.
        # Without this, every chunk's activations stay referenced by the graph
        # until the call ends.
        # NOTE(review): chainer.config.train is not changed here; confirm that
        # callers run inference with train=False if the streams contain
        # train-dependent layers.
        with chainer.no_backprop_mode():
            audio_feat = self.aud(chainer.Variable(chainer.cuda.to_gpu(audio)))
            visual_feat = self._mean_visual_features(frames, chunk_size)
            fused = chainer.functions.concat([audio_feat, visual_feat])
            scores = (chainer.functions.tanh(self.fc(fused)) + 1) / 2
        return chainer.cuda.to_cpu(scores.data[0])

    def _mean_visual_features(self, frames, chunk_size):
        """Return visual-stream outputs averaged over all rows of ``frames``."""
        total = None
        for start in range(0, frames.shape[0], chunk_size):
            chunk = chainer.cuda.to_gpu(frames[start:start + chunk_size])
            chunk_sum = chainer.functions.sum(self.vis(chainer.Variable(chunk)), 0)
            total = chunk_sum if total is None else total + chunk_sum
        # Restore a leading axis of size 1 so the result can be concatenated
        # with the audio features along axis 1.
        return chainer.functions.expand_dims(total, 0) / frames.shape[0]
### MODEL ###