Don't include lhs of := in results of predict()
penelopeysm committed Dec 21, 2024
1 parent 6657441 commit bc74e8c
Showing 4 changed files with 91 additions and 61 deletions.
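For orientation (not part of the commit): the updated tests attach a deterministic quantity to the model with `:=`, and the intent of this change is that such quantities no longer appear in the chain returned by `predict`. A minimal sketch, reusing the `linear_reg` model from the updated tests below; the chain construction is elided:

```julia
using DynamicPPL, Distributions

@model function linear_reg(x, y, σ=0.1)
    β ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
    # Deterministic value on the LHS of `:=`; with this commit it is
    # extracted only when explicitly requested, so `predict` omits it.
    σ2 := σ^2
end

# Sketch of the intended behaviour:
# predictions = DynamicPPL.predict(linear_reg(xs_test, fill(missing, 2)), β_chain)
# Set(keys(predictions)) == Set([Symbol("y[1]"), Symbol("y[2]")])   # no σ2
```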
2 changes: 1 addition & 1 deletion ext/DynamicPPLMCMCChainsExt.jl
@@ -117,7 +117,7 @@ function DynamicPPL.predict(
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
model(rng, varinfo, DynamicPPL.SampleFromPrior())

vals = DynamicPPL.values_as_in_model(model, varinfo)
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
varname_vals = mapreduce(
collect,
vcat,
27 changes: 17 additions & 10 deletions src/values_as_in_model.jl
@@ -22,22 +22,26 @@ $(TYPEDFIELDS)
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
"values that are extracted from the model"
values::T
"whether to extract variables on the LHS of :="
include_colon_eq::Bool
"child context"
context::C
end

ValuesAsInModelContext(values) = ValuesAsInModelContext(values, DefaultContext())
function ValuesAsInModelContext(context::AbstractContext)
return ValuesAsInModelContext(OrderedDict(), context)
# If child context is not passed
ValuesAsInModelContext(values, include_colon_eq) = ValuesAsInModelContext(values, include_colon_eq, DefaultContext())
# If values are not passed
function ValuesAsInModelContext(include_colon_eq, context::AbstractContext)
return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context)
end

NodeTrait(::ValuesAsInModelContext) = IsParent()
childcontext(context::ValuesAsInModelContext) = context.context
function setchildcontext(context::ValuesAsInModelContext, child)
return ValuesAsInModelContext(context.values, child)
return ValuesAsInModelContext(context.values, context.include_colon_eq, child)
end

is_extracting_values(context::ValuesAsInModelContext) = true
is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq
function is_extracting_values(context::AbstractContext)
return is_extracting_values(NodeTrait(context), context)
end
@@ -114,8 +118,8 @@ function dot_tilde_assume(
end

"""
values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(rng::Random.AbstractRNG, model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
Get the values of `varinfo` as they would be seen in the model.
@@ -132,6 +136,7 @@ of additional model evaluations.
# Arguments
- `model::Model`: model to extract realizations from.
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
@@ -183,24 +188,26 @@ false
julia> # Approach 2: Extract realizations using `values_as_in_model`.
# (✓) `values_as_in_model` will re-run the model and extract
# the correct realization of `y` given the new values of `x`.
lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub
lb ≤ values_as_in_model(model, true, varinfo_linked)[@varname(y)] ≤ ub
true
```
"""
function values_as_in_model(
model::Model,
include_colon_eq::Bool,
varinfo::AbstractVarInfo=VarInfo(),
context::AbstractContext=DefaultContext(),
)
context = ValuesAsInModelContext(context)
context = ValuesAsInModelContext(include_colon_eq, context)
evaluate!!(model, varinfo, context)
return context.values
end
function values_as_in_model(
rng::Random.AbstractRNG,
model::Model,
include_colon_eq::Bool,
varinfo::AbstractVarInfo=VarInfo(),
context::AbstractContext=DefaultContext(),
)
return values_as_in_model(model, varinfo, SamplingContext(rng, context))
return values_as_in_model(model, true, varinfo, SamplingContext(rng, context))
end
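As a usage note (a sketch based on the updated docstring and tests, not code from the commit; the `demo` model is hypothetical, and `values_as_in_model` is assumed to be in scope as in the package's own tests), the new `include_colon_eq` positional argument controls whether `:=`-defined variables appear in the result:

```julia
using DynamicPPL, Distributions

# Hypothetical model: `x` is a random variable, `y` is deterministic.
@model function demo()
    x ~ Normal()
    y := 2 * x
end

model = demo()
varinfo = VarInfo(model)

# Include variables on the LHS of `:=`: both `x` and `y` are returned.
vals = values_as_in_model(model, true, varinfo)
haskey(vals, @varname(y))   # true

# Exclude them: only the sampled `x` remains (this is what `predict` now uses).
vals = values_as_in_model(model, false, varinfo)
haskey(vals, @varname(y))   # false
```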
11 changes: 9 additions & 2 deletions test/compiler.jl
@@ -702,10 +702,17 @@ module Issue537 end
@test haskey(varinfo, @varname(x))
@test !haskey(varinfo, @varname(y))

# While `values_as_in_model` should contain both `x` and `y`.
values = values_as_in_model(model, deepcopy(varinfo))
# While `values_as_in_model` should contain both `x` and `y`, if
# include_colon_eq is set to `true`.
values = values_as_in_model(model, true, deepcopy(varinfo))
@test haskey(values, @varname(x))
@test haskey(values, @varname(y))

# And if include_colon_eq is set to `false`, then `values` should
# only contain `x`.
values = values_as_in_model(model, false, deepcopy(varinfo))
@test haskey(values, @varname(x))
@test !haskey(values, @varname(y))
end
end
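The model exercised by this testset is not shown in the hunk; a hypothetical stand-in consistent with the comments above (a sampled `x`, and a `y` defined via `:=` so it is absent from the `VarInfo` but recoverable via `values_as_in_model`) would be:

```julia
# Hypothetical model matching the test's description.
@model function demo_colon_eq()
    x ~ Normal()
    y := 100 + x   # deterministic; not stored in the VarInfo
end
```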

112 changes: 64 additions & 48 deletions test/model.jl
@@ -90,12 +90,12 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
samples = (; samples_dict...)
samples = modify_value_representation(samples) # `modify_value_representation` defined in test/test_util.jl
@test logpriors[i] ≈
DynamicPPL.TestUtils.logprior_true(model, samples[:s], samples[:m])
DynamicPPL.TestUtils.logprior_true(model, samples[:s], samples[:m])
@test loglikelihoods[i] ≈ DynamicPPL.TestUtils.loglikelihood_true(
model, samples[:s], samples[:m]
)
@test logjoints[i] ≈
DynamicPPL.TestUtils.logjoint_true(model, samples[:s], samples[:m])
DynamicPPL.TestUtils.logjoint_true(model, samples[:s], samples[:m])
end
end
end
@@ -283,10 +283,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
# Ensure log-probability computations are implemented.
@test logprior(model, x) ≈ DynamicPPL.TestUtils.logprior_true(model, x...)
@test loglikelihood(model, x) ≈
DynamicPPL.TestUtils.loglikelihood_true(model, x...)
DynamicPPL.TestUtils.loglikelihood_true(model, x...)
@test logjoint(model, x) ≈ DynamicPPL.TestUtils.logjoint_true(model, x...)
@test logjoint(model, x) !=
DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...)
DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...)
# Ensure `varnames` is implemented.
vi = last(
DynamicPPL.evaluate!!(
@@ -383,7 +383,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
realizations = values_as_in_model(model, varinfo)
# We can set the include_colon_eq arg to false because none of
# the demo models contain :=. The behaviour when
# include_colon_eq is true is tested in test/compiler.jl
realizations = values_as_in_model(model, false, varinfo)
# Ensure that all variables are found.
vns_found = collect(keys(realizations))
@test vns ∩ vns_found == vns ∪ vns_found
@@ -432,72 +435,85 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()

@testset "predict" begin
@testset "with MCMCChains.Chains" begin
DynamicPPL.Random.seed!(100)

@model function linear_reg(x, y, σ=0.1)
β ~ Normal(0, 1)
for i in eachindex(y)
y[i] ~ Normal(β * x[i], σ)
end
# Insert a := block to test that it is not included in predictions
σ2 := σ^2
end

@model function linear_reg_vec(x, y, σ=0.1)
β ~ Normal(0, 1)
return y ~ MvNormal(β .* x, σ^2 * I)
end

# Construct a chain with 'sampled values' of β
ground_truth_β = 2
β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [:β])

# Generate predictions from that chain
xs_test = [10 + 0.1, 10 + 2 * 0.1]
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, β_chain)

ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
@test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01

# Ensure that `rng` is respected
rng = MersenneTwister(42)
predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
predictions2 = DynamicPPL.predict(
MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
)
@test all(Array(predictions1) .== Array(predictions2))

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, β_chain)
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

@test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
# Also test a vectorized model
@model function linear_reg_vec(x, y, σ=0.1)
β ~ Normal(0, 1)
return y ~ MvNormal(β .* x, σ^2 * I)
end
m_lin_reg_test_vec = linear_reg_vec(xs_test, missing)

# Multiple chains
multiple_β_chain = MCMCChains.Chains(
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β]
)
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
@test size(multiple_β_chain, 3) == size(predictions, 3)
@testset "variables in chain" begin
# Note that this also checks that variables on the lhs of :=,
# such as σ2, are not included in the resulting chain
@test Set(keys(predictions)) == Set([Symbol("y[1]"), Symbol("y[2]")])
end

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
@testset "accuracy" begin
ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
@test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
end

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred_vec = vec(
mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
@testset "ensure that rng is respected" begin
rng = MersenneTwister(42)
predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
predictions2 = DynamicPPL.predict(
MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
)
@test all(Array(predictions1) .== Array(predictions2))
end

@testset "accuracy on vectorized model" begin
predictions_vec = DynamicPPL.predict(m_lin_reg_test_vec, β_chain)
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

@test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
end

@testset "prediction from multiple chains" begin
# Normal linreg model
multiple_β_chain = MCMCChains.Chains(
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β]
)
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
@test size(multiple_β_chain, 3) == size(predictions, 3)

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
@test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
end

# Vectorized linreg model
predictions_vec = DynamicPPL.predict(m_lin_reg_test_vec, multiple_β_chain)

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred_vec = vec(
mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
)
@test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
end
end
end

@testset "with AbstractVector{<:AbstractVarInfo}" begin
@@ -524,7 +540,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()

@test size(predicted_vis) == size(chain)
@test Set(keys(predicted_vis[1])) ==
Set([@varname(β), @varname(y[1]), @varname(y[2])])
Set([@varname(β), @varname(y[1]), @varname(y[2])])
# because β samples are from the prior, the std will be larger
@test mean([
predicted_vis[i][@varname(y[1])] for i in eachindex(predicted_vis)
