posteriors do not include distribution for actual data, how to estimate/compute the distribution for y? #135
-
Hi there, I am following the infinite stream example and, as I got to understand, it plots the distributions of the hidden state however I rather would like to plot the distributions of the data Adding a var # Noisy observation
y_rv ~ Normal(mean = x_current, precision = τ)
y = datavar(Float64)
y ~ y_rv does not work, as the q factorization is somehow broken by it. Hence there must be some other standard way to get the distribution of datavars like `y`. Every help is highly appreciated. |
Beta Was this translation helpful? Give feedback.
Replies: 3 comments 6 replies
-
@bvdmitri can you help? |
Beta Was this translation helpful? Give feedback.
-
Hey @schlichtanders Solution 1: It's best illustrated by an example, simple random walk model. # test #3 (array + single entry for predictvars)
data = (y = [1.0, -1.0, 0.9, missing], )
# Random-walk state-space model over latent states `x` with noisy
# observations `y`. `n` is the number of observed time steps; one extra
# latent state `x[n+1]` and a data variable `o` with no attached value are
# appended so that `o` can be predicted via `predictvars`.
@model function example_model_3(n)
# `n + 1` latent states: `n` observed steps plus one prediction step
x = randomvar(n + 1)
# `allow_missing` permits `missing` entries in `y` (entries to be predicted)
y = datavar(Float64, n) where { allow_missing = true }
# Prior on the initial latent state
z ~ NormalMeanPrecision(0.0, 1.0)
z_prev = z
for i in 1:n
# Latent random walk with unit transition precision
x[i] ~ NormalMeanPrecision(z_prev, 1.0)
# Noisy observation of the current latent state (unit precision)
y[i] ~ NormalMeanPrecision(x[i], 1.0)
z_prev = x[i]
end
# One-step-ahead "observation" with no data attached — it will be predicted
o = datavar(Float64) where { allow_missing = true }
x[n+1] ~ NormalMeanPrecision(x[n], 1.0)
o ~ NormalMeanPrecision(x[n+1], 1.0)
end
result = inference(model = example_model_3(length(data[:y])), iterations=10, data = data, predictvars=(o = KeepLast(), )) The problem with this solution is that it doesn't work with Solution 2: Solution 3 |
Beta Was this translation helpful? Give feedback.
-
@schlichtanders That is a nice question and @albertpod has indeed worked on a solution for that in a separate branch. Currently, the user-friendly solution is not available and has not been tested properly. We are working on it. As a workaround, you can manually extract this information like the following: For demonstration purposes, I will use the same infinite stream example.
# One-step Kalman-filter model for the infinite-stream example.
# The priors for the previous state `x_prev` and the noise precision `τ` are
# supplied through data variables so they can be refreshed between ticks via
# `@autoupdates`. Returns the factor-graph node handler for the observation,
# which is later used to extract the predictive distribution of `y`.
@model function kalman_filter()
# Prior for the previous state (parameters fed in as datavars)
x_prev_mean = datavar(Float64)
x_prev_var = datavar(Float64)
x_prev ~ Normal(mean = x_prev_mean, variance = x_prev_var)
# Prior for the observation noise precision
τ_shape = datavar(Float64)
τ_rate = datavar(Float64)
τ ~ Gamma(shape = τ_shape, rate = τ_rate)
# Random walk with fixed transition precision
x_current ~ Normal(mean = x_prev, precision = 1.0)
# Noisy observation
y = datavar(Float64)
# [CHANGED HERE]
# The `~` tilde operator optionally returns a handler for the node in the
# factor graph (here bound to `ynode`)
ynode, y ~ Normal(mean = x_current, precision = τ)
# [CHANGED HERE]
# Return the node handler so it is accessible via `engine.returnval` later
return ynode
end
# Create a custom event listener, read documentation for the `rxinference` for more details (section Events)
# Long story short, after each tick of the inference engine
# The event listener function will manually extract the predictive distribution from the `return` statement in the
# `@model` macro. The event listener then subscribes to the latest available predictive distribution and saves it to an array
# Build an `:on_tick` event listener for the inference `engine`.
# After each tick it extracts the predictive distribution of the observation
# `y` from the node handler returned by the `@model` macro and pushes it to
# the `predictive_distributions` array (mutated in place).
# See the `rxinference` documentation (section "Events") for details.
function create_event_listener(engine, predictive_distributions)
return (event::RxInferenceEvent{ :on_tick }) -> begin
# `engine.returnval` holds `...` from the `return ...` statement in the `@model` macro
ynode = engine.returnval
# Internal API: get an observable for the outgoing message towards the data variable `y`
# (already implemented in the `dev-predict` branch, but not yet compatible with `rxinference`)
message = ReactiveMP.messageout(ReactiveMP.getinterface(ynode, :out))
# Materialize the message into an actual distribution (calls analytical rules under the hood)
predictive = message |> map(Any, (update) -> ReactiveMP.getdata(ReactiveMP.as_message(update)))
# Subscribe and save the result; `take(1)` completes after the first
# (latest available) estimate, so no manual unsubscribe is needed
subscribe!(predictive |> take(1), (distribution) -> push!(predictive_distributions, distribution))
end
end
# Run reactive inference over a static `datastream`, collecting one
# predictive distribution for `y` per processed observation via the
# `:on_tick` event listener. Returns `(engine, predictive_distributions)`.
# NOTE(review): `environment` is unused in this body — presumably kept for
# API symmetry with a streaming variant; confirm against the caller.
function run_static(environment, datastream)
# `@autoupdates` specifies how to update our priors based on new posteriors:
# every time a posterior over `x_current` (and `τ`) is computed, the prior
# datavars for the next tick are refreshed from it
autoupdates = @autoupdates begin
x_prev_mean, x_prev_var = mean_var(q(x_current))
τ_shape = shape(q(τ))
τ_rate = rate(q(τ))
end
# Accumulator filled by the event listener, one entry per tick
predictive_distributions = []
engine = rxinference(
model = kalman_filter(),
constraints = filter_constraints(),
datastream = datastream,
autoupdates = autoupdates,
returnvars = (:x_current, ),
keephistory = 10_000,
historyvars = (x_current = KeepLast(), τ = KeepLast()),
initmarginals = (x_current = NormalMeanVariance(0.0, 1e3), τ = GammaShapeRate(1.0, 1.0)),
iterations = 10,
free_energy = true,
# Do not start immediately — the event listener must be attached first
autostart = false,
# Enable the `:on_tick` event so the listener fires after every tick
events = Val((:on_tick, ))
)
subscription = subscribe!(engine.events, create_event_listener(engine, predictive_distributions))
RxInfer.start(engine)
# Since we are using a static dataset it is safe to unsubscribe right after
# starting the engine; for asynchronous datasets, unsubscribe later
unsubscribe!(subscription)
return engine, predictive_distributions
end
result, predictives = run_static(static_environment, static_datastream);
# Animate the filtering results frame by frame.
# Fixes relative to the original snippet:
#  * ribbons use the standard deviation (`std.`) rather than the variance
#    (`var.`), so the band is a meaningful confidence interval — as confirmed
#    later in this thread;
#  * all slices use `1:i` (the original mixed `1:i` means with `1:n` ribbons,
#    giving mismatched series lengths within each frame).
static_inference = @gif for i in 1:n
estimated = result.history[:x_current]
p = plot(1:i, mean.(estimated[1:i]), ribbon = std.(estimated[1:i]), label = "Estimation")
p = plot!(p, 1:i, mean.(predictives[1:i]), ribbon = std.(predictives[1:i]), label = "Predictive")
p = plot!(static_history[1:i], label = "Real states")
p = scatter!(static_observations[1:i], ms = 2, label = "Observations")
p = plot(p, size = (1000, 300), legend = :bottomright)
end
Beta Was this translation helpful? Give feedback.
Ah, I think you are right. The confidence interval or "ribbons" in the demo should indeed use `std` (or 3 × std) instead of `var`.