-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBP.py
148 lines (124 loc) · 4.06 KB
/
BP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from itertools import product
from math import log, e
class BP:
    """Loopy belief propagation on a factor graph.

    The graph ``g`` must expose ``g.rvs`` and ``g.factors``.  As used below,
    each random variable has ``value`` (``None`` when unconstrained, otherwise
    the value it is clamped to), ``domain.values`` and ``nb`` (its
    neighbouring factors); each factor has ``nb`` (its variables, in argument
    order) and ``potential.get(joint_assignment)`` returning the factor's
    weight for that assignment.

    Messages are stored in ``self.message`` keyed by ``(sender, receiver)``;
    each message is a dict mapping a candidate value of the variable involved
    to a weight.  ``self.points`` maps each variable to its candidate values
    and is filled in by ``run()``.
    """

    def __init__(self, g, max_prod=False):
        self.g = g
        self.message = dict()
        self.points = dict()
        # When True, run() sends max-product instead of sum-product
        # factor-to-variable messages.
        self.max_prod = max_prod

    def init_points(self):
        """Return a dict mapping every variable to its candidate values.

        Variables whose ``value`` is ``None`` range over their whole domain;
        any other variable is clamped to a 1-tuple holding its value.
        """
        return {
            rv: rv.domain.values if rv.value is None else (rv.value,)
            for rv in self.g.rvs
        }

    def message_rv_to_f(self, x, rv, f):
        """Return the unnormalised variable-to-factor message m_{rv->f}(x).

        For an unclamped variable this is the product of the messages ``rv``
        receives from every neighbouring factor except ``f``.  A clamped
        variable sends a constant 1 (its only candidate is the clamped value).
        """
        if rv.value is not None:
            return 1
        res = 1
        for nb in rv.nb:
            if nb != f:
                res *= self.message[(nb, rv)][x]
        return res

    def message_f_to_rv(self, x, f, rv, max_prod=False):
        """Return the unnormalised factor-to-variable message m_{f->rv}(x).

        Every joint assignment of ``f``'s variables with ``rv`` fixed to ``x``
        is enumerated.  Each assignment's potential is weighted by the current
        variable-to-factor messages of the other variables.  The weighted
        terms are summed (sum-product) or, when ``max_prod`` is true,
        maximised (max-product).
        """
        param = [(x,) if nb == rv else self.points[nb] for nb in f.nb]
        res = 0
        for joint_x in product(*param):
            m = 1
            for i, rv_ in enumerate(f.nb):
                if rv_ != rv:
                    m *= self.message[(rv_, f)][joint_x[i]]
            term = f.potential.get(joint_x) * m
            if max_prod:
                # Starting the running max at 0 assumes non-negative potentials.
                res = max(term, res)
            else:
                res += term
        return res

    @staticmethod
    def normalize_message(message):
        """Scale the values of ``message`` in place so they sum to 1.

        An all-zero message is left untouched to avoid dividing by zero.
        """
        z = sum(message.values())
        if z == 0:
            return
        for k in message:
            message[k] /= z

    def belief(self, x, rv):
        """Return the unnormalised belief that ``rv`` takes value ``x``.

        This is the product of all messages ``rv`` receives from its factors.
        """
        b = 1
        for nb in rv.nb:
            b *= self.message[(nb, rv)][x]
        return b

    def factor_belief(self, x, f):
        """Return the unnormalised belief of factor ``f`` at joint value ``x``.

        ``x`` is a tuple ordered like ``f.nb``.  The potential is multiplied by
        the variable-to-factor message of each of ``f``'s variables.
        """
        b = f.potential.get(x)
        for i, nb in enumerate(f.nb):
            b *= self.message_rv_to_f(x[i], nb, f)
        return b

    def map(self, rv):
        """Return the candidate value of ``rv`` with the highest belief.

        Ties go to the earliest value in ``self.points[rv]``.
        """
        return max(self.points[rv], key=lambda x: self.belief(x, rv))

    def prob(self, rv):
        """Return the normalised marginal of ``rv`` as a ``{value: prob}`` dict."""
        p = {x: self.belief(x, rv) for x in self.points[rv]}
        self.normalize_message(p)
        return p

    def factor_prob(self, f):
        """Return the normalised joint belief over ``f``'s variables.

        Keys are tuples of values ordered like ``f.nb``.
        """
        param = tuple(self.points[rv] for rv in f.nb)
        p = {x: self.factor_belief(x, f) for x in product(*param)}
        self.normalize_message(p)
        return p

    def partition(self):
        """Estimate the partition function Z from the current beliefs.

        Computes

            log Z ~= sum_f sum_x b_f(x) * log(psi_f(x) * prod_i b_i(x_i) / b_f(x))
                     - sum_i sum_x b_i(x) * log(b_i(x))

        and returns ``e ** log Z``.  When the factor and variable marginals
        agree, as at a BP fixed point, this is the negative Bethe free
        energy, which is exact on tree-structured graphs.  Zero-probability
        terms contribute nothing (0 * log 0 == 0).
        """
        rvs_p = {rv: self.prob(rv) for rv in self.g.rvs}
        z = 0
        for f in self.g.factors:
            for joint_x, f_b in self.factor_prob(f).items():
                if f_b == 0:
                    continue
                b = 1
                for i, nb in enumerate(f.nb):
                    b *= rvs_p[nb][joint_x[i]]
                z += f_b * log(f.potential.get(joint_x) * b / f_b)
        for rv in self.g.rvs:
            for b in rvs_p[rv].values():
                # 0 * log(0) is 0 by convention; calling log(0) would raise.
                if b > 0:
                    z -= b * log(b)
        return e ** z

    def run(self, iteration=10):
        """Run ``iteration`` rounds of belief propagation.

        First resets the candidate values and all messages, and initialises
        every factor-to-variable message to all ones.  Each round then
        recomputes every variable-to-factor message, followed by every
        factor-to-variable message, normalising each one as it is stored.
        """
        self.points = self.init_points()
        self.message = dict()
        for f in self.g.factors:
            for rv in f.nb:
                self.message[(f, rv)] = {k: 1 for k in self.points[rv]}
        for _ in range(iteration):
            # Variable -> factor messages (read only factor -> variable ones).
            for rv in self.g.rvs:
                for f in rv.nb:
                    m = {x: self.message_rv_to_f(x, rv, f) for x in self.points[rv]}
                    self.normalize_message(m)
                    self.message[(rv, f)] = m
            # Factor -> variable messages (read only variable -> factor ones).
            for f in self.g.factors:
                for rv in f.nb:
                    m = {
                        x: self.message_f_to_rv(x, f, rv, max_prod=self.max_prod)
                        for x in self.points[rv]
                    }
                    self.normalize_message(m)
                    self.message[(f, rv)] = m