diff --git a/Project.toml b/Project.toml index ffc57b2..2f44d8e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "GPLikelihoods" uuid = "6031954c-0455-49d7-b3b9-3e1c99afaf40" authors = ["JuliaGaussianProcesses Team"] -version = "0.4.3" +version = "0.4.4" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/expectations.jl b/src/expectations.jl index 3de986b..85db8c3 100644 --- a/src/expectations.jl +++ b/src/expectations.jl @@ -87,25 +87,13 @@ function expected_loglikelihood( # Compute the expectation via Gauss-Hermite quadrature # using a reparameterisation by change of variable # (see e.g. en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) - return sum(Broadcast.instantiate( - Broadcast.broadcasted(y, q_f) do yᵢ, q_fᵢ # Loop over every pair - # of marginal distribution q(fᵢ) and observation yᵢ - expected_loglikelihood(gh, lik, q_fᵢ, yᵢ) - end, - )) -end - -# Compute the expected_loglikelihood for one observation and a marginal distributions -function expected_loglikelihood(gh::GaussHermiteExpectation, lik, q_f::Normal, y) - μ = mean(q_f) - σ̃ = sqrt2 * std(q_f) - return invsqrtπ * sum(Broadcast.instantiate( - Broadcast.broadcasted(gh.xs, gh.ws) do x, w # Loop over every - # pair of Gauss-Hermite point x with weight w - f = σ̃ * x + μ - loglikelihood(lik(f), y) * w - end, - )) + # PR #90 introduces eager instead of lazy broadcast over observations + # and Gauss-Hermite points and weights in order to make the function + # type stable. Compared to other type stable implementations, e.g. + # using a custom two-argument pairwise sum, this is faster to + # differentiate using Zygote. 
+ A = loglikelihood.(lik.(sqrt2 .* std.(q_f) .* gh.xs' .+ mean.(q_f)), y) .* gh.ws' + return invsqrtπ * sum(A) end function expected_loglikelihood( diff --git a/src/likelihoods/negativebinomial.jl b/src/likelihoods/negativebinomial.jl index 1c2ea40..f663d4c 100644 --- a/src/likelihoods/negativebinomial.jl +++ b/src/likelihoods/negativebinomial.jl @@ -147,9 +147,11 @@ struct NBParamII{T} <: NBParamMean end function (l::NegativeBinomialLikelihood{<:NBParamII})(f::Real) - μ = l.invlink(f) - ev = l.params.α * μ - return NegativeBinomial(_nb_mean_excessvar_to_r_p(μ, ev)...) + # Simplify parameter conversions and avoid splatting + α = l.params.α + r = inv(α) + p = inv(one(α) + α * l.invlink(f)) + return NegativeBinomial(r, p) end """ diff --git a/test/expectations.jl b/test/expectations.jl index 65fe96e..9063e7d 100644 --- a/test/expectations.jl +++ b/test/expectations.jl @@ -1,14 +1,17 @@ @testset "expectations" begin - # Test that the various methods of computing expectations return the same - # result. 
rng = MersenneTwister(123456) q_f = Normal.(zeros(10), ones(10)) likelihoods_to_test = [ + BernoulliLikelihood(), ExponentialLikelihood(), GammaLikelihood(), - PoissonLikelihood(), GaussianLikelihood(), + NegativeBinomialLikelihood(NBParamSuccess(1.0)), + NegativeBinomialLikelihood(NBParamFailure(1.0)), + NegativeBinomialLikelihood(NBParamI(1.0)), + NegativeBinomialLikelihood(NBParamII(1.0)), + PoissonLikelihood(), ] @testset "testing all analytic implementations" begin @@ -30,30 +33,50 @@ end end - @testset "$(nameof(typeof(lik)))" for lik in likelihoods_to_test - methods = [ - GaussHermiteExpectation(100), - MonteCarloExpectation(1e7), - GPLikelihoods.DefaultExpectationMethod(), - ] - def = GPLikelihoods.default_expectation_method(lik) - if def isa GPLikelihoods.AnalyticExpectation - push!(methods, def) - end - y = rand.(rng, lik.(zeros(10))) + @testset "testing consistency of different expectation methods" begin + @testset "$(nameof(typeof(lik)))" for lik in likelihoods_to_test + # Test that the various methods of computing expectations return the same + # result. 
+                methods = [ +                    GaussHermiteExpectation(100), +                    MonteCarloExpectation(1e7), +                    GPLikelihoods.DefaultExpectationMethod(), +                ] +                def = GPLikelihoods.default_expectation_method(lik) +                if def isa GPLikelihoods.AnalyticExpectation +                    push!(methods, def) +                end +                y = rand.(rng, lik.(zeros(10))) -        results = map(m -> GPLikelihoods.expected_loglikelihood(m, lik, q_f, y), methods) -        @test all(x -> isapprox(x, results[end]; atol=1e-6, rtol=1e-3), results) +                results = map( +                    m -> GPLikelihoods.expected_loglikelihood(m, lik, q_f, y), methods +                ) +                @test all(x -> isapprox(x, results[end]; atol=1e-6, rtol=1e-3), results) +            end end -    @test GPLikelihoods.expected_loglikelihood( -        MonteCarloExpectation(1), GaussianLikelihood(), q_f, zeros(10) -    ) isa Real -    @test GPLikelihoods.expected_loglikelihood( -        GaussHermiteExpectation(1), GaussianLikelihood(), q_f, zeros(10) -    ) isa Real -    @test GPLikelihoods.default_expectation_method(θ -> Normal(0, θ)) isa -        GaussHermiteExpectation +    @testset "testing return types and type stability" begin +        @test GPLikelihoods.expected_loglikelihood( +            MonteCarloExpectation(1), GaussianLikelihood(), q_f, zeros(10) +        ) isa Real +        @test GPLikelihoods.expected_loglikelihood( +            GaussHermiteExpectation(1), GaussianLikelihood(), q_f, zeros(10) +        ) isa Real +        @test GPLikelihoods.default_expectation_method(θ -> Normal(0, θ)) isa +            GaussHermiteExpectation +
+        @testset "$(nameof(typeof(lik)))" for lik in likelihoods_to_test +            # Test that `expected_loglikelihood` is type-stable +            y = rand.(rng, lik.(zeros(10))) +            for method in [ +                MonteCarloExpectation(100), +                GaussHermiteExpectation(100), +                GPLikelihoods.DefaultExpectationMethod(), +            ] +                @test (@inferred expected_loglikelihood(method, lik, q_f, y)) isa Real +            end +        end +    end # see https://github.com/JuliaGaussianProcesses/ApproximateGPs.jl/issues/82 @testset "testing Zygote compatibility with GaussHermiteExpectation" begin