relevanceCalculator.py

# hacked from https://github.com/hila-chefer/Transformer-MM-Explainability

import torch
import numpy as np


# rule 6 from paper
def apply_self_attention_rules(R_ss, cam_ss):
    R_ss_addition = torch.matmul(cam_ss, R_ss)
    return R_ss_addition


def normalizeCam(cam, nW):
    cam = cam / torch.max(cam)  # torch.max(cam[:nW, :nW])
    return cam


def normalizeR(R):
    R_ = R - torch.eye(R.shape[0])
    R_ /= R.sum(dim=1, keepdim=True)
    return R_ + torch.eye(R.shape[0])
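

# Illustrative sketch (not part of the original file): the rule-6 update above
# accumulates relevance as R <- R + cam @ R, starting from an identity matrix and
# applied once per transformer block. On a hypothetical 3x3 attention map the
# propagation would look roughly like:
#
#     R = torch.eye(3)
#     cam = torch.softmax(torch.rand(3, 3), dim=-1)  # stand-in for block.getJuiceFlow()
#     R = R + apply_self_attention_rules(R, cam)
#     R = normalizeR(R)  # rescale while keeping the identity (residual) part
#
# after which R[i, j] indicates how strongly token j contributes to token i.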


def generate_relevance_(model, input, index=None):

    device = next(model.parameters()).device

    output, cls = model(input, analysis=True)

    if index is None:
        index = np.argmax(output.cpu().data.numpy(), axis=-1)

    # accumulate gradients on attentions
    one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    one_hot = torch.from_numpy(one_hot)
    one_hot = torch.sum(one_hot.to(device) * output)

    model.zero_grad()
    one_hot.backward()

    # construct the initial relevance matrix
    shiftSize = model.shiftSize
    windowSize = model.hyperParams.windowSize

    T = input.shape[-1]  # number of bold tokens
    dynamicLength = ((T - windowSize) // shiftSize) * shiftSize + windowSize

    nW = model.last_numberOfWindows  # number of cls tokens
    num_tokens = dynamicLength + nW

    R = torch.eye(num_tokens, num_tokens)

    # now pass the relevance matrix through the blocks
    for block in model.blocks:
        cam = block.getJuiceFlow().cpu()
        R += apply_self_attention_rules(R, cam)
        R = normalizeR(R)

    del one_hot
    del output
    torch.cuda.empty_cache()

    # R.shape = (dynamicLength + nW, dynamicLength + nW)
    # get the part that the window cls tokens are interested in
    # here we have relevance of each window cls token to the bold tokens
    inputToken_relevances = R[:nW, nW:]  # of shape (nW, dynamicLength)

    return inputToken_relevances  # (nW, dynamicLength)
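

# Note: generate_relevance below is identical to generate_relevance_ above except
# that it omits the normalizeR call when propagating relevance through the blocks.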
def generate_relevance(model, input, index=None):

    device = next(model.parameters()).device

    output, cls = model(input, analysis=True)

    if index is None:
        index = np.argmax(output.cpu().data.numpy(), axis=-1)

    # accumulate gradients on attentions
    one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    one_hot = torch.from_numpy(one_hot)
    one_hot = torch.sum(one_hot.to(device) * output)

    model.zero_grad()
    one_hot.backward()

    # construct the initial relevance matrix
    shiftSize = model.shiftSize
    windowSize = model.hyperParams.windowSize

    T = input.shape[-1]  # number of bold tokens
    dynamicLength = ((T - windowSize) // shiftSize) * shiftSize + windowSize

    nW = model.last_numberOfWindows  # number of cls tokens
    num_tokens = dynamicLength + nW

    R = torch.eye(num_tokens, num_tokens)

    # now pass the relevance matrix through the blocks
    for block in model.blocks:
        cam = block.getJuiceFlow().cpu()
        R += apply_self_attention_rules(R, cam)

    del one_hot
    del output
    torch.cuda.empty_cache()

    # R.shape = (dynamicLength + nW, dynamicLength + nW)
    # get the part that the window cls tokens are interested in
    # here we have relevance of each window cls token to the bold tokens
    inputToken_relevances = R[:nW, nW:]  # of shape (nW, dynamicLength)

    return inputToken_relevances  # (nW, dynamicLength)
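

# Usage sketch (an assumption, not from the original file): `model` is expected to be
# a trained BolT-style transformer whose forward pass accepts `analysis=True` and which
# exposes `blocks` (each with a `getJuiceFlow()` method), `shiftSize`,
# `hyperParams.windowSize`, and `last_numberOfWindows`. Given such a model and a BOLD
# input tensor whose last dimension is time, per-timepoint relevances could be obtained
# roughly as:
#
#     relevances = generate_relevance(model, boldInput)  # (nW, dynamicLength)
#     timeRelevance = relevances.mean(dim=0)             # average over window cls tokens
#
# `boldInput` and the averaging step are hypothetical names/choices for illustration.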