Pass multivariate data as univariate (collected) input to function #246

Open
wouterwln opened this issue Jul 2, 2024 · 2 comments


@wouterwln
Member

Let's look at the following model:

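```julia
# Assumes `@model` (from GraphPPL/RxInfer) and `Normal` (from Distributions.jl) are in scope.
function dot end

@model function foo(x, y)
    local w
    for i in 1:length(x)
        w[i] ~ Normal(0, 1)
        x[i] ~ Normal(0, 1)
    end
    y ~ dot(x, w)
end
```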

Now `x` will be created as a vector of data variables because we call `x[i] ~ ...`. However, when we pass it to `dot`, it is still a `ProxyLabel` with `maycreate=True()`, since that is how it was passed into the model. Under the hood, this calls `getorcreate!` without an interface and hence throws `ERROR: Variable x is already a vector variable in the model.` Nasty bug, and I don't really know how to fix it (yet).
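To make the failure mode concrete, here is a heavily simplified, hypothetical variable registry. `ToyContext` and `toy_getorcreate!` are illustrative names, not GraphPPL's API; the sketch only assumes that indexed calls register `x` element-wise and that a later unindexed lookup of the same name is rejected with the error quoted above.

```julia
# Hypothetical, heavily simplified stand-in for a model's variable registry.
# `ToyContext` and `toy_getorcreate!` are illustrative names, NOT GraphPPL's API.
struct ToyContext
    # name => :single, or name => Dict(index => element name) for vector variables
    vars::Dict{Symbol, Any}
end

ToyContext() = ToyContext(Dict{Symbol, Any}())

# Unindexed lookup: get or create a single variable called `name`.
function toy_getorcreate!(ctx::ToyContext, name::Symbol, ::Nothing)
    existing = get(ctx.vars, name, nothing)
    if existing isa Dict
        # `name` was already registered element-wise, e.g. by `x[i] ~ ...`
        error("Variable $name is already a vector variable in the model")
    end
    return get!(ctx.vars, name, :single)
end

# Indexed lookup: get or create element `index` of the vector variable `name`.
function toy_getorcreate!(ctx::ToyContext, name::Symbol, index::Int)
    existing = get!(ctx.vars, name, Dict{Int, Symbol}())
    existing isa Dict || error("Variable $name is already a single variable in the model")
    return get!(existing, index, Symbol(name, "_", index))
end

ctx = ToyContext()
for i in 1:3
    toy_getorcreate!(ctx, :x, i)          # analogous to `x[i] ~ Normal(0, 1)`
end

# Analogous to resolving `x` without an index inside `y ~ dot(x, w)`:
try
    toy_getorcreate!(ctx, :x, nothing)
catch err
    println(sprint(showerror, err))       # Variable x is already a vector variable in the model
end
```

The point of the toy is only that the same name cannot be treated both as a vector (element-wise) and as a single variable; how GraphPPL actually stores variables is not modelled here.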

@blolt

blolt commented Jul 31, 2024

Ran into this recently when specifying a Poisson GLM:

```julia
@model function poisson_glm(X, y, n, m)
    local θ
    for j in 1:m
        θ[j] ~ Normal(0, 1)
    end

    local λ
    for i in 1:n
        λ[i] := dot(X[i, :], θ)
        y[i] ~ Poisson(exp(λ[i]))
    end
end
```

This would be a nice generalization of this Turing.jl example.

I'm looking to pick up more on RxInfer's backend packages, so this one might make sense since I stumbled across it independently. Is there a design for GraphPPL I could reference?

Edit: Will start here: https://reactivebayes.github.io/GraphPPL.jl/stable/developers_guide/#Developers-guide

@wouterwln
Member Author

Hi @blolt, thanks for checking this out! Indeed, the Developers Guide is the closest thing we have to a description of the design of GraphPPL. I'll try to give some additional pointers. GraphPPL is split into two (maybe three, but for the sake of this argument let's keep it at two) separate modules: a graph engine (containing code for the creation and manipulation of a probabilistic model represented as a factor graph) and a metaprogramming module (which transforms user code into code the graph engine can interpret). We're mainly interested in the graph engine here.

The first thing to note is that GraphPPL does not represent the model as an FFG (Forney-style factor graph). Instead, it has both factor nodes and variable nodes, so the entire graph is bipartite.
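In an FFG, variables are edges between factors; in the bipartite representation described above, variables are nodes of their own and every edge joins a factor node to a variable node. The sketch below illustrates only that structural property. The types and functions (`VariableNode`, `FactorNode`, `BipartiteGraph`, `add_variable!`, `add_factor!`) are hypothetical and are not GraphPPL's data structures.

```julia
# Illustrative sketch only: a bipartite graph with two node kinds, where every
# edge joins a factor node to a variable node. Not GraphPPL's data structure.
struct VariableNode
    name::Symbol
end

struct FactorNode
    fform::Symbol                       # e.g. :Normal or :dot
end

struct BipartiteGraph
    variables::Vector{VariableNode}
    factors::Vector{FactorNode}
    edges::Vector{Tuple{Int, Int}}      # (factor id, variable id)
end

BipartiteGraph() = BipartiteGraph(VariableNode[], FactorNode[], Tuple{Int, Int}[])

function add_variable!(g::BipartiteGraph, name::Symbol)
    push!(g.variables, VariableNode(name))
    return length(g.variables)          # id of the new variable node
end

# A factor may only connect to existing variable nodes, so factor-factor and
# variable-variable edges cannot arise: the graph is bipartite by construction.
function add_factor!(g::BipartiteGraph, fform::Symbol, variable_ids::Vector{Int})
    all(v -> 1 <= v <= length(g.variables), variable_ids) ||
        throw(ArgumentError("unknown variable id in $variable_ids"))
    push!(g.factors, FactorNode(fform))
    f = length(g.factors)
    for v in variable_ids
        push!(g.edges, (f, v))
    end
    return f                            # id of the new factor node
end

# `y ~ dot(x, w)` becomes one factor node connected to three variable nodes.
g = BipartiteGraph()
x, w, y = add_variable!(g, :x), add_variable!(g, :w), add_variable!(g, :y)
add_factor!(g, :dot, [y, x, w])
```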

Whenever we create a (sub)model, we assume that all interfaces (arguments to the function) are available. For the top-level model this is trivial, but for nested models some of them might still have to be created. Also, for data, as in this case, we don't know whether we should pass X as a matrix-variate RV, a vector of multivariate RVs, or a matrix of univariate RVs. That's why we came up with a trick: whenever you pass something to a node or submodel, we don't materialize the variable node (yet), but pass a ProxyLabel instead:

"""
ProxyLabel(name, index, proxied)
A label that proxies another label in a probabilistic graphical model.
The proxied objects must implement the `is_proxied(::Type) = True()`.
The proxy labels may spawn new variables in a model, if `maycreate` is set to `True()`.
"""
mutable struct ProxyLabel{P, I, M}
const name::Symbol
const proxied::P
const index::I
const maycreate::M
end

This `maycreate` field denotes whether we are allowed to create a new variable node when the variable is used in the creation of an atomic factor node. Currently this field is handled incorrectly: if we use a label with `maycreate=True()` on the right-hand side of a statement (as I'm doing in `y ~ dot(x, w)`, where `x` is still a `ProxyLabel` because it is supplied from outside the model), it still tries to create `x` instead of fetching it from the existing variables. This is all very deep in the details of GraphPPL's design, so it's okay if you do not understand all of it. If you are interested in the internal workings of GraphPPL, I would start by playing around with a simple `GraphPPL.Model`: build a model with the nested model functionality, look at the `GraphPPL.Context` attached to it, and get a feeling for what objects live where.
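The contract described above (fetch a variable that already exists, create one only when the label is allowed to) can be sketched as follows. This is a hypothetical illustration, not GraphPPL's resolution logic: `ToyProxy` and `resolve!` are made-up names, a `Dict` stands in for the model's existing variables, and a plain `Bool` replaces the `True()`/`False()` values.

```julia
# Hypothetical sketch of the contract a `maycreate` flag expresses.
# `ToyProxy` and `resolve!` are illustrative names, not GraphPPL internals,
# and a plain Bool replaces GraphPPL's True()/False() values.
struct ToyProxy
    name::Symbol
    maycreate::Bool
end

# `registry` stands in for the variables that already exist in the model.
function resolve!(registry::Dict{Symbol, Any}, p::ToyProxy)
    if haskey(registry, p.name)
        return registry[p.name]                            # fetch the existing variable
    elseif p.maycreate
        return registry[p.name] = Symbol(p.name, "_new")   # create it on first use
    else
        error("Variable $(p.name) does not exist and may not be created here")
    end
end

registry = Dict{Symbol, Any}(:x => [:x_1, :x_2, :x_3])  # `x` already exists as a vector
resolve!(registry, ToyProxy(:x, false))   # returns the existing vector; nothing is created
resolve!(registry, ToyProxy(:w, true))    # `w` is missing and creation is allowed
```

In this toy, an existing variable is always fetched regardless of the flag; the bug described in the thread is precisely that the real code path takes the "create" route for `x` even though `x` already exists.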

I think in the end the bug is here:

```julia
function proxylabel(::True, name::Symbol, proxied::ProxyLabel, index::Any, maycreate::Any)
    return ProxyLabel(name, proxied, index, proxied.maycreate | maycreate)
end
```

The expression `proxied.maycreate | maycreate` returns `True()` whenever `X` could be created by the outer proxy, even though it should not be created in this specific instance. Hope this helps!
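A self-contained way to see why the OR is the problem: once the outer proxy carries `True()`, every proxy nested on top of it also carries `True()`, no matter what the use site asks for. The sketch below uses local stand-ins `StaticTrue`/`StaticFalse` for the `True()`/`False()` values, a hypothetical `Proxy` type without the `index` and `proxied` fields, and a hypothetical `nest` function that mirrors only the flag combination in `proxylabel`. It assumes, for illustration, that the right-hand-side use in `dot(x, w)` requests `False()`.

```julia
# Local stand-ins for static booleans; illustrative only.
struct StaticTrue end
struct StaticFalse end

Base.:|(::StaticTrue, ::StaticTrue) = StaticTrue()
Base.:|(::StaticTrue, ::StaticFalse) = StaticTrue()
Base.:|(::StaticFalse, ::StaticTrue) = StaticTrue()
Base.:|(::StaticFalse, ::StaticFalse) = StaticFalse()

# Hypothetical simplified proxy: only the fields needed for the flag logic.
struct Proxy{M}
    name::Symbol
    maycreate::M
end

# Mirrors the combination rule quoted above: the nested proxy's flag is the OR
# of the outer (proxied) flag and the flag requested at this use site.
nest(proxied::Proxy, maycreate) = Proxy(proxied.name, proxied.maycreate | maycreate)

outer = Proxy(:x, StaticTrue())         # `x` as it is passed into the model
inner = nest(outer, StaticFalse())      # use site that should only fetch `x`
println(inner.maycreate)                # StaticTrue(): creation is still allowed
```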
