Skip to content

Commit

Permalink
Merge pull request #90 from simsurace/ss/type_stability2
Browse files Browse the repository at this point in the history
Some performance optimizations
  • Loading branch information
theogf authored Jul 8, 2022
2 parents e9b7da9 + 82f58e0 commit 1cb8238
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 47 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GPLikelihoods"
uuid = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
authors = ["JuliaGaussianProcesses Team"]
version = "0.4.3"
version = "0.4.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
26 changes: 7 additions & 19 deletions src/expectations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,25 +87,13 @@ function expected_loglikelihood(
# Compute the expectation via Gauss-Hermite quadrature
# using a reparameterisation by change of variable
# (see e.g. en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)
return sum(Broadcast.instantiate(
Broadcast.broadcasted(y, q_f) do yᵢ, q_fᵢ # Loop over every pair
# of marginal distribution q(fᵢ) and observation yᵢ
expected_loglikelihood(gh, lik, q_fᵢ, yᵢ)
end,
))
end

# Compute the expected_loglikelihood for one observation and a marginal distributions
function expected_loglikelihood(gh::GaussHermiteExpectation, lik, q_f::Normal, y)
μ = mean(q_f)
σ̃ = sqrt2 * std(q_f)
return invsqrtπ * sum(Broadcast.instantiate(
Broadcast.broadcasted(gh.xs, gh.ws) do x, w # Loop over every
# pair of Gauss-Hermite point x with weight w
f = σ̃ * x + μ
loglikelihood(lik(f), y) * w
end,
))
# PR #90 introduces eager instead of lazy broadcast over observations
# and Gauss-Hermit points and weights in order to make the function
# type stable. Compared to other type stable implementations, e.g.
# using a custom two-argument pairwise sum, this is faster to
# differentiate using Zygote.
A = loglikelihood.(lik.(sqrt2 .* std.(q_f) .* gh.xs' .+ mean.(q_f)), y) .* gh.ws'
return invsqrtπ * sum(A)
end

function expected_loglikelihood(
Expand Down
8 changes: 5 additions & 3 deletions src/likelihoods/negativebinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,11 @@ struct NBParamII{T} <: NBParamMean
end

function (l::NegativeBinomialLikelihood{<:NBParamII})(f::Real)
μ = l.invlink(f)
ev = l.params.α * μ
return NegativeBinomial(_nb_mean_excessvar_to_r_p(μ, ev)...)
# Simplify parameter conversions and avoid splatting
α = l.params.α
r = inv(α)
p = inv(one(α) + α * l.invlink(f))
return NegativeBinomial(r, p)
end

"""
Expand Down
71 changes: 47 additions & 24 deletions test/expectations.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
@testset "expectations" begin
# Test that the various methods of computing expectations return the same
# result.
rng = MersenneTwister(123456)
q_f = Normal.(zeros(10), ones(10))

likelihoods_to_test = [
BernoulliLikelihood(),
ExponentialLikelihood(),
GammaLikelihood(),
PoissonLikelihood(),
GaussianLikelihood(),
NegativeBinomialLikelihood(NBParamSuccess(1.0)),
NegativeBinomialLikelihood(NBParamFailure(1.0)),
NegativeBinomialLikelihood(NBParamI(1.0)),
NegativeBinomialLikelihood(NBParamII(1.0)),
PoissonLikelihood(),
]

@testset "testing all analytic implementations" begin
Expand All @@ -30,30 +33,50 @@
end
end

@testset "$(nameof(typeof(lik)))" for lik in likelihoods_to_test
methods = [
GaussHermiteExpectation(100),
MonteCarloExpectation(1e7),
GPLikelihoods.DefaultExpectationMethod(),
]
def = GPLikelihoods.default_expectation_method(lik)
if def isa GPLikelihoods.AnalyticExpectation
push!(methods, def)
end
y = rand.(rng, lik.(zeros(10)))
@testset "testing consistency of different expectation methods" begin
@testset "$(nameof(typeof(lik)))" for lik in likelihoods_to_test
# Test that the various methods of computing expectations return the same
# result.
methods = [
GaussHermiteExpectation(100),
MonteCarloExpectation(1e7),
GPLikelihoods.DefaultExpectationMethod(),
]
def = GPLikelihoods.default_expectation_method(lik)
if def isa GPLikelihoods.AnalyticExpectation
push!(methods, def)
end
y = rand.(rng, lik.(zeros(10)))

results = map(m -> GPLikelihoods.expected_loglikelihood(m, lik, q_f, y), methods)
@test all(x -> isapprox(x, results[end]; atol=1e-6, rtol=1e-3), results)
results = map(
m -> GPLikelihoods.expected_loglikelihood(m, lik, q_f, y), methods
)
@test all(x -> isapprox(x, results[end]; atol=1e-6, rtol=1e-3), results)
end
end

@test GPLikelihoods.expected_loglikelihood(
MonteCarloExpectation(1), GaussianLikelihood(), q_f, zeros(10)
) isa Real
@test GPLikelihoods.expected_loglikelihood(
GaussHermiteExpectation(1), GaussianLikelihood(), q_f, zeros(10)
) isa Real
@test GPLikelihoods.default_expectation_method-> Normal(0, θ)) isa
GaussHermiteExpectation
@testset "testing return types and type stability" begin
@test GPLikelihoods.expected_loglikelihood(
MonteCarloExpectation(1), GaussianLikelihood(), q_f, zeros(10)
) isa Real
@test GPLikelihoods.expected_loglikelihood(
GaussHermiteExpectation(1), GaussianLikelihood(), q_f, zeros(10)
) isa Real
@test GPLikelihoods.default_expectation_method-> Normal(0, θ)) isa
GaussHermiteExpectation

@testset "$(nameof(typeof(lik)))" for lik in likelihoods_to_test
# Test that `expectec_loglikelihood` is type-stable
y = rand.(rng, lik.(zeros(10)))
for method in [
MonteCarloExpectation(100),
GaussHermiteExpectation(100),
GPLikelihoods.DefaultExpectationMethod(),
]
@test (@inferred expected_loglikelihood(method, lik, q_f, y)) isa Real
end
end
end

# see https://github.com/JuliaGaussianProcesses/ApproximateGPs.jl/issues/82
@testset "testing Zygote compatibility with GaussHermiteExpectation" begin
Expand Down

2 comments on commit 1cb8238

@theogf
Copy link
Member Author

@theogf theogf commented on 1cb8238 Jul 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/63882

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.4 -m "<description of version>" 1cb82383c6583bde573571f6a1d858e34c40aa10
git push origin v0.4.4

Please sign in to comment.