Hello, I'm trying to learn RxInfer.jl and variational message passing in general, so please bear with me if what I'm trying to do is not possible. Any help would be greatly appreciated.

I am trying to implement this model. My naive implementation with RxInfer.jl is:

```julia
@model function my_model(obs, N, sigma)
    local p
    for i = 1:N
        p[i] ~ Beta(1, 1)
    end
    local x
    for i = 1:N
        x[i] ~ Bernoulli(p[i])
    end
    local C
    for i = 1:N
        C ~ C + x[i]
    end
    obs ~ NormalMeanVariance(C, sigma^2)
end
```

I was able to plot the model graph with:

```julia
conditioned = my_model(N = 3, sigma = 0.1) | (obs = 3,)
conditioned = RxInfer.create_model(conditioned)
graph = GraphPlot.gplot(RxInfer.getmodel(conditioned))
```

so the graph looks fine. When I run:

```julia
results = infer(
    model = my_model(N = 3, sigma = 0.1),
    data = (obs = 2.0,),
)
```

I get an error.
So I go to the docs and I see that one needs to initialize messages if the model has loops. But when I look at the graph, I do not see any loops. Despite that, I try:

```julia
results = infer(
    model = my_model(N = 3, sigma = 0.1),
    data = (obs = 2.0,),
    initialization = @initialization(μ(C) = PointMass(0)),
)
```

and now I'm hit with:

> `ProductOf` object cannot be used as a functional form in inference backend. Use form constraints to restrict the functional form of marginal posteriors.

I am not exactly sure how to go from here. My guess is that I would need to define a custom node for the operation that aggregates all the Bernoulli variables. My first thought was to implement the Poisson Binomial distribution as a new node in RxInfer.jl following this tutorial, but I'm not 100% sure if this is the way to go, so I wanted to check for some guidance. Thank you very much!
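For reference, the generative model encoded by the code above is

$$
p_i \sim \mathrm{Beta}(1, 1), \qquad
x_i \mid p_i \sim \mathrm{Bernoulli}(p_i), \qquad
C = \sum_{i=1}^{N} x_i, \qquad
\mathrm{obs} \mid C \sim \mathcal{N}(C, \sigma^2),
$$

so that, conditionally on $p_1, \dots, p_N$, the count $C$ follows a Poisson Binomial distribution with parameters $p_1, \dots, p_N$ — which is why a Poisson Binomial node seems like a natural fit.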
---
It sounds like there's an issue with the model definition where an expression like `C ~ C + x[i]` is used: `~` creates a new random variable, so `C` cannot appear on both sides as an accumulator. The current representation in the printed graph might also be misleading, as the graph can look fine even though the expression is not doing what you expect. A potential workaround could be:

```julia
accum_C = x[1]
for i = 2:N
    next_C[i - 1] ~ accum_C + x[i]
    accum_C = next_C[i - 1]
end
obs ~ NormalMeanVariance(accum_C, sigma^2)
```

This adjustment avoids the problematic expression and maintains the intended behavior of creating a sum of the `x[i]`. Note, however, that (as far as I remember) we do not have summation rules for `Bernoulli` variables, so message passing may still fail on the `+` nodes. @wouterwln, it appears we need to address this.

P.S. Opened an issue.
---
One possible solution to your issue could be as follows:

```julia
using RxInfer, StableRNGs, Optimisers

add(x...) = sum(x)

@model function my_model(obs, N, sigma)
    local p
    for i = 1:N
        # I changed the prior to a non-flat version
        p[i] ~ Beta(2, 2)
    end
    local x
    for i = 1:N
        x[i] ~ Bernoulli(p[i])
    end
    # Use the custom `add` function for the sum
    C ~ add(in = x)
    # I also added more observations to get better posteriors
    for i in eachindex(obs)
        obs[i] ~ NormalMeanVariance(C, sigma^2)
    end
end
```
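Because `add` is an arbitrary Julia function rather than a node with built-in message-update rules, the messages around it have to be approximated; that is what the meta specification below sets up.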
We enable `CVI` as an approximation method for the `add` function:

```julia
@meta function my_meta(rng, nr_samples, nr_iterations, optimizer)
    add() -> CVI(rng, nr_samples, nr_iterations, optimizer, ForwardDiffGrad(), 10, Val(false))
end
```
We write an initialization, which is required by `CVI` rather than by the model structure:

```julia
@initialization function my_initialization()
    q(C) = NormalMeanVariance(0, 100)
    q(p) = Beta(2, 2)
end
```

Let's generate a dataset:

```julia
# Let's pretend all our observations are just `3`, and that we have 500 of them.
dataset = fill(3, 500)
```

When I run the inference:

```julia
results = infer(
    model = my_model(N = 3, sigma = 0.1),
    data = (obs = dataset,),
    constraints = MeanField(),  # required for `CVI`, enables variational inference
    iterations = 10,            # number of VI iterations
    initialization = my_initialization(),
    meta = my_meta(StableRNG(42), 1000, 1000, Optimisers.Descent(1e-7)),
)
```

I get the following result:

```julia
mean.(results.posteriors[:p]) # [0.599686, 0.599614, 0.599845]
mean.(results.posteriors[:x]) # [ 0.998428, 0.99807, 0.999225 ]
mean(results.posteriors[:C].argument) # 2.998
```

These results seem plausible given the same priors on `p`.
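As a quick, hypothetical sanity check (not from the original reply): for small `N` the exact posterior over `C` can be computed by brute force, enumerating all `x ∈ {0,1}^N` and marginalising each `p[i]` in closed form via Beta-Bernoulli conjugacy. The function name and keyword defaults below are illustrative.

```julia
function exact_posterior_mean_C(dataset, N, sigma; a = 2.0, b = 2.0)
    # Prior predictive of x[i] under p[i] ~ Beta(a, b): P(x[i] = 1) = a / (a + b)
    px1 = a / (a + b)
    num, den = 0.0, 0.0
    for bits in Iterators.product(ntuple(_ -> (0, 1), N)...)
        C = sum(bits)
        # Prior probability of this x configuration, p[i] marginalised out
        prior = prod(x == 1 ? px1 : 1 - px1 for x in bits)
        # Gaussian log-likelihood of all observations given C (constants cancel)
        loglik = sum(-(y - C)^2 / (2 * sigma^2) for y in dataset)
        w = prior * exp(loglik)
        num += w * C
        den += w
    end
    return num / den
end

exact_posterior_mean_C(fill(3, 500), 3, 0.1)  # ≈ 3.0, matching the inferred mean of C
```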
---
@arnauqb I'd like to build on the response from @bvdmitri. To make the most of our inference engine, consider creating your own rules (as suggested by @bvdmitri). While we plan to introduce more methods for simplifying message approximation, it's important to note that each method has its drawbacks. We've been discussing efficient inference strategies for this model with @ismailsenoz, @Nimrais, and @wouterwln. I hope they can share their insights when possible. In the meantime, please don't hesitate to ask questions here or start new discussions! Cheers!
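To give a concrete flavour of the "create your own rules" route: the `@node` and `@rule` macros are the real extension points in RxInfer/ReactiveMP, but the node name, edge layout, and rule body below are illustrative assumptions, not a tested implementation.

```julia
using RxInfer

# Hypothetical node type for C ~ PoissonBinomial(p), where `p` is a vector
# of success probabilities. The edge layout is an assumption.
struct PoissonBinomial end

@node PoissonBinomial Stochastic [out, p]

# pmf of the Poisson Binomial via the standard O(N^2) recursion;
# probs[k + 1] = P(C = k) for k = 0, ..., N.
function poisson_binomial_pmf(ps)
    probs = [1.0]
    for p in ps
        probs = [probs; 0.0] .* (1 - p) .+ [0.0; probs] .* p
    end
    return probs
end

# Sketch of a single message-update rule towards `out`, assuming the inbound
# marginal `q_p` exposes a mean vector of success probabilities. A complete
# node would also need rules towards `p` (and possibly free-energy terms).
@rule PoissonBinomial(:out, Marginalisation) (q_p::Any,) = begin
    # NOTE: `Categorical` has support 1:(N + 1), so this encodes C + 1.
    return Categorical(poisson_binomial_pmf(mean(q_p)))
end
```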