Skip to content

Commit

Permalink
more tests for arbitrary logpdf
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Oct 11, 2023
1 parent fb4d2f0 commit 83b7d15
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 61 deletions.
4 changes: 2 additions & 2 deletions src/densities/continouslogpdf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ function BayesBase.convert_paramfloattype(
end

function BayesBase.vague(::Type{<:ContinuousUnivariateLogPdf})
return ContinuousUnivariateLogPdf(DomainSets.FullSpace(), (x) -> 1)
return ContinuousUnivariateLogPdf(DomainSets.FullSpace(), (x) -> 0)
end

# We do not check typeof of a different functions because in most of the cases lambdas have different types, but they can still be the same
Expand Down Expand Up @@ -211,7 +211,7 @@ function BayesBase.convert_paramfloattype(
end

function BayesBase.vague(::Type{<:ContinuousMultivariateLogPdf}, dims::Int)
return ContinuousMultivariateLogPdf(DomainSets.FullSpace()^dims, (x) -> 1)
return ContinuousMultivariateLogPdf(DomainSets.FullSpace()^dims, (x) -> 0)
end

# We do not check typeof of a different functions because in most of the cases lambdas have different types, but they can still be the same
Expand Down
141 changes: 82 additions & 59 deletions test/densities/continouslogpdf_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ end
d = vague(ContinuousUnivariateLogPdf)

@test typeof(d) <: ContinuousUnivariateLogPdf
@test d(rand()) 1
@test d(rand()) 0
end

@testitem "ContinuousUnivariateLogPdf: prod" begin
Expand Down Expand Up @@ -230,31 +230,32 @@ end
end

@testitem "ContinuousUnivariateLogPdf: convert" begin
d = DomainSets.FullSpace()
import DomainSets: FullSpace

d = FullSpace()
l = (x) -> 1.0

c = convert(ContinuousUnivariateLogPdf, d, l)
@test typeof(c) <: ContinuousUnivariateLogPdf
@test isapprox(c, ContinuousUnivariateLogPdf(d, l), atol=1e-12)

c2 = convert(ContinuousUnivariateLogPdf, c)
@test typeof(c2) <: ContinuousUnivariateLogPdf
@test isapprox(c2, ContinuousUnivariateLogPdf(d, l), atol=1e-12)
end

@testitem "ContinuousMultivariateLogPdf: Constructor" begin
import DomainSets: FullSpace

f = (x) -> -x'x
dist = ContinuousMultivariateLogPdf(2, f)
d2 = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, f)
d2 = ContinuousMultivariateLogPdf(FullSpace()^2, f)

@test typeof(dist) === typeof(d2)
@test dist d2
@test paramfloattype(dist) === Float64
@test samplefloattype(dist) === Float64
@test paramfloattype(d2) === Float64
@test samplefloattype(d2) === Float64

@test_throws AssertionError ContinuousMultivariateLogPdf(DomainSets.FullSpace(), f)
@test_throws AssertionError ContinuousMultivariateLogPdf(FullSpace(), f)
@test_throws MethodError ContinuousMultivariateLogPdf(f)
end

Expand All @@ -272,7 +273,9 @@ end
end

@testitem "ContinuousMultivariateLogPdf: pdf/logpdf" begin
dist = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> -x'x)
import DomainSets: FullSpace, HalfLine

dist = ContinuousMultivariateLogPdf(FullSpace()^2, (x) -> -x'x)

f32_points1 = range(Float32(-10.0), Float32(10.0); length=5)
f64_points1 = range(-10.0, 10.0; length=5)
Expand All @@ -288,7 +291,7 @@ end
@test all(map(p -> -p'p == logpdf(dist, p), points1))
@test all(map(p -> exp(-p'p) == pdf(dist, p), points1))

d2 = ContinuousMultivariateLogPdf(DomainSets.HalfLine()^2, (x) -> -x'x / 4)
d2 = ContinuousMultivariateLogPdf(HalfLine()^2, (x) -> -x'x / 4)

f32_points2 = range(Float32(0.0), Float32(10.0); length=5)
f64_points2 = range(0.0, 10.0; length=5)
Expand All @@ -306,9 +309,11 @@ end
end

@testitem "ContinuousMultivariateLogPdf: test domain in logpdf" begin
import DomainSets: FullSpace, HalfLine

for dim in (2, 3, 4)
dist = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^dim, (x) -> -x'x)
d2 = ContinuousMultivariateLogPdf(DomainSets.HalfLine()^dim, (x) -> -x'x)
dist = ContinuousMultivariateLogPdf(FullSpace()^dim, (x) -> -x'x)
d2 = ContinuousMultivariateLogPdf(HalfLine()^dim, (x) -> -x'x)

# This also throws a warning in stdout
@test_logs (:warn, r".*incompatible combination.*") @test_throws AssertionError logpdf(
Expand All @@ -321,89 +326,107 @@ end
end

@testitem "ContinuousMultivariateLogPdf: vague" begin
d = vague(ContinuousMultivariateLogPdf, 2)
for k in 2:5
d = vague(ContinuousMultivariateLogPdf, k)

@test typeof(d) <: ContinuousMultivariateLogPdf
@test d ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> 2.0)
@test typeof(d) <: ContinuousMultivariateLogPdf
@test d(rand(k)) 0
end
end

@testitem "ContinuousMultivariateLogPdf: prod" begin
dist = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> 2.0 * -x'x)
d2 = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> 3.0 * -x'x)
import DomainSets: FullSpace, HalfLine

product = prod(ProdAnalytical(), dist, d2)
dist = ContinuousMultivariateLogPdf(FullSpace()^2, (x) -> 2.0 * -x'x)
d2 = ContinuousMultivariateLogPdf(FullSpace()^2, (x) -> 3.0 * -x'x)

pr1 = prod(GenericProd(), dist, d2)
pt1 = ContinuousMultivariateLogPdf(
DomainSets.FullSpace()^2, (x) -> logpdf(dist, x) + logpdf(d2, x)
FullSpace()^2, (x) -> logpdf(dist, x) + logpdf(d2, x)
)

@test getdomain(product) === getdomain(dist)
@test getdomain(product) === getdomain(d2)
@test variate_form(typeof(product)) === variate_form(typeof(dist))
@test variate_form(typeof(product)) === variate_form(typeof(d2))
@test value_support(typeof(product)) === value_support(typeof(dist))
@test value_support(typeof(product)) === value_support(typeof(d2))
@test support(product) === support(dist)
@test support(product) === support(d2)
@test isapprox(product, pt1, atol=1e-12)
@test variate_form(typeof(pr1)) === variate_form(typeof(dist))
@test variate_form(typeof(pr1)) === variate_form(typeof(d2))
@test value_support(typeof(pr1)) === value_support(typeof(dist))
@test value_support(typeof(pr1)) === value_support(typeof(d2))
@test support(pr1) === support(dist)
@test support(pr1) === support(d2)

for x in [randn(2) for _ in 1:10]
@test isapprox(logpdf(pr1, x), logpdf(pt1, x))
@test isapprox(pdf(pr1, x), pdf(pt1, x))
end

result = ContinuousMultivariateLogPdf(DomainSets.HalfLine()^2, (x) -> 2.0 * -x'x)
d4 = ContinuousMultivariateLogPdf(DomainSets.HalfLine()^2, (x) -> 3.0 * -x'x)
result = ContinuousMultivariateLogPdf(HalfLine()^2, (x) -> 2.0 * -x'x)
d4 = ContinuousMultivariateLogPdf(HalfLine()^2, (x) -> 3.0 * -x'x)

pr2 = prod(ProdAnalytical(), result, d4)
pr2 = prod(GenericProd(), result, d4)
pt2 = ContinuousMultivariateLogPdf(
DomainSets.HalfLine()^2, (x) -> logpdf(result, x) + logpdf(d4, x)
HalfLine()^2, (x) -> logpdf(result, x) + logpdf(d4, x)
)

@test getdomain(pr2) === getdomain(result)
@test getdomain(pr2) === getdomain(d4)
@test variate_form(typeof(pr2)) === variate_form(typeof(result))
@test variate_form(typeof(pr2)) === variate_form(typeof(d4))
@test value_support(typeof(pr2)) === value_support(typeof(result))
@test value_support(typeof(pr2)) === value_support(typeof(d4))
@test support(pr2) === support(result)
@test support(pr2) === support(d4)
@test isapprox(pr2, pt2, atol=1e-12)

@test !isapprox(product, pr2; atol=1e-12)
for x in [rand(2) for _ in 1:10]
@test isapprox(logpdf(pr2, x), logpdf(pt2, x))
@test isapprox(pdf(pr2, x), pdf(pt2, x))
end

d5 = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> 2.0 * -x'x)
d6 = ContinuousMultivariateLogPdf(DomainSets.HalfLine()^2, (x) -> 2.0 * -x'x)
@test_throws AssertionError prod(ProdAnalytical(), d5, d6)
d5 = ContinuousMultivariateLogPdf(FullSpace()^2, (x) -> 2.0 * -x'x)
d6 = ContinuousMultivariateLogPdf(HalfLine()^2, (x) -> 2.0 * -x'x)
@test_throws AssertionError logpdf(prod(GenericProd(), d5, d6), [1.0, -1.0]) # domains are incompatible
end

@testitem "ContinuousMultivariateLogPdf: vectorised-prod" begin
import DomainSets: FullSpace, HalfLine

f = (x) -> 2.0 * -x'x
dist = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, f)
d2 = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, f)
result = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> f(x) + f(x))
dist = ContinuousMultivariateLogPdf(FullSpace()^2, f)
result = ContinuousMultivariateLogPdf(FullSpace()^2, (x) -> 3 * f(x))

product = prod(ProdAnalytical(), dist, d2)
product = prod(GenericProd(), prod(GenericProd(), dist, dist), dist)

@test product isa GenericLogPdfVectorisedProduct
@test getdomain(product) === getdomain(dist)
@test getdomain(product) === getdomain(d2)
@test product isa LinearizedProductOf
@test variate_form(typeof(product)) === variate_form(typeof(dist))
@test variate_form(typeof(product)) === variate_form(typeof(d2))
@test variate_form(typeof(product)) === variate_form(typeof(result))
@test value_support(typeof(product)) === value_support(typeof(dist))
@test value_support(typeof(product)) === value_support(typeof(d2))
@test value_support(typeof(product)) === value_support(typeof(result))
@test support(product) === support(dist)
@test support(product) === support(d2)
@test support(product) === support(result)

for x in [rand(2) for _ in 1:10]
@test pdf(product, x) pdf(result, x)
@test logpdf(product, x) logpdf(result, x)
end

for point in [rand(Float64, 2) for _ in 1:10]
@test pdf(product, point) pdf(result, point)
@test logpdf(product, point) logpdf(result, point)
# Test internal side-effects
another_product = prod(GenericProd(), product, dist)

for x in [rand(2) for _ in 1:10]
@test logpdf(product, x) logpdf(result, x)
@test pdf(product, x) pdf(result, x)

@test logpdf(another_product, x) (logpdf(product, x) + logpdf(dist, x))
@test pdf(another_product, x) (pdf(product, x) * pdf(dist, x))
end
end

@testitem "ContinuousMultivariateLogPdf: convert" begin
d = DomainSets.FullSpace()^2
l = (x) -> 1.0
import DomainSets: FullSpace, HalfLine

for k in 2:5
d = FullSpace()^k
l = (x) -> 0

c = convert(ContinuousMultivariateLogPdf, d, l)
@test typeof(c) <: ContinuousMultivariateLogPdf
@test isapprox(c, ContinuousMultivariateLogPdf(d, l), atol=1e-12)
c = convert(ContinuousMultivariateLogPdf, d, l)
@test typeof(c) <: ContinuousMultivariateLogPdf

c2 = convert(ContinuousMultivariateLogPdf, c)
@test typeof(c2) <: ContinuousMultivariateLogPdf
@test isapprox(c2, ContinuousMultivariateLogPdf(d, l), atol=1e-12)
c2 = convert(ContinuousMultivariateLogPdf, c)
@test typeof(c2) <: ContinuousMultivariateLogPdf
end
end

0 comments on commit 83b7d15

Please sign in to comment.