diff --git a/src/densities/continouslogpdf.jl b/src/densities/continouslogpdf.jl index f27d3ac..a1f99a6 100644 --- a/src/densities/continouslogpdf.jl +++ b/src/densities/continouslogpdf.jl @@ -136,7 +136,7 @@ function BayesBase.convert_paramfloattype( end function BayesBase.vague(::Type{<:ContinuousUnivariateLogPdf}) - return ContinuousUnivariateLogPdf(DomainSets.FullSpace(), (x) -> 1) + return ContinuousUnivariateLogPdf(DomainSets.FullSpace(), (x) -> 0) end # We do not check typeof of a different functions because in most of the cases lambdas have different types, but they can still be the same @@ -211,7 +211,7 @@ function BayesBase.convert_paramfloattype( end function BayesBase.vague(::Type{<:ContinuousMultivariateLogPdf}, dims::Int) - return ContinuousMultivariateLogPdf(DomainSets.FullSpace()^dims, (x) -> 1) + return ContinuousMultivariateLogPdf(DomainSets.FullSpace()^dims, (x) -> 0) end # We do not check typeof of a different functions because in most of the cases lambdas have different types, but they can still be the same diff --git a/test/densities/continouslogpdf_tests.jl b/test/densities/continouslogpdf_tests.jl index b7941cc..a780232 100644 --- a/test/densities/continouslogpdf_tests.jl +++ b/test/densities/continouslogpdf_tests.jl @@ -146,7 +146,7 @@ end d = vague(ContinuousUnivariateLogPdf) @test typeof(d) <: ContinuousUnivariateLogPdf - @test d(rand()) ≈ 1 + @test d(rand()) ≈ 0 end @testitem "ContinuousUnivariateLogPdf: prod" begin @@ -230,31 +230,32 @@ end end @testitem "ContinuousUnivariateLogPdf: convert" begin - d = DomainSets.FullSpace() + import DomainSets: FullSpace + + d = FullSpace() l = (x) -> 1.0 c = convert(ContinuousUnivariateLogPdf, d, l) @test typeof(c) <: ContinuousUnivariateLogPdf - @test isapprox(c, ContinuousUnivariateLogPdf(d, l), atol=1e-12) c2 = convert(ContinuousUnivariateLogPdf, c) @test typeof(c2) <: ContinuousUnivariateLogPdf - @test isapprox(c2, ContinuousUnivariateLogPdf(d, l), atol=1e-12) end @testitem "ContinuousMultivariateLogPdf: Constructor" begin + import DomainSets: FullSpace + f = (x) -> -x'x dist = ContinuousMultivariateLogPdf(2, f) - d2 = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, f) + d2 = ContinuousMultivariateLogPdf(FullSpace()^2, f) @test typeof(dist) === typeof(d2) - @test dist ≈ d2 @test paramfloattype(dist) === Float64 @test samplefloattype(dist) === Float64 @test paramfloattype(d2) === Float64 @test samplefloattype(d2) === Float64 - @test_throws AssertionError ContinuousMultivariateLogPdf(DomainSets.FullSpace(), f) + @test_throws AssertionError ContinuousMultivariateLogPdf(FullSpace(), f) @test_throws MethodError ContinuousMultivariateLogPdf(f) end @@ -272,7 +273,9 @@ end end @testitem "ContinuousMultivariateLogPdf: pdf/logpdf" begin - dist = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> -x'x) + import DomainSets: FullSpace, HalfLine + + dist = ContinuousMultivariateLogPdf(FullSpace()^2, (x) -> -x'x) f32_points1 = range(Float32(-10.0), Float32(10.0); length=5) f64_points1 = range(-10.0, 10.0; length=5) @@ -288,7 +291,7 @@ end @test all(map(p -> -p'p == logpdf(dist, p), points1)) @test all(map(p -> exp(-p'p) == pdf(dist, p), points1)) - d2 = ContinuousMultivariateLogPdf(DomainSets.HalfLine()^2, (x) -> -x'x / 4) + d2 = ContinuousMultivariateLogPdf(HalfLine()^2, (x) -> -x'x / 4) f32_points2 = range(Float32(0.0), Float32(10.0); length=5) f64_points2 = range(0.0, 10.0; length=5) @@ -306,9 +309,11 @@ end end @testitem "ContinuousMultivariateLogPdf: test domain in logpdf" begin + import DomainSets: FullSpace, HalfLine + for dim in (2, 3, 4) - dist = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^dim, (x) -> -x'x) - d2 = ContinuousMultivariateLogPdf(DomainSets.HalfLine()^dim, (x) -> -x'x) + dist = ContinuousMultivariateLogPdf(FullSpace()^dim, (x) -> -x'x) + d2 = ContinuousMultivariateLogPdf(HalfLine()^dim, (x) -> -x'x) # This also throws a warning in stdout @test_logs (:warn, r".*incompatible combination.*") @test_throws AssertionError logpdf( @@ -321,89 +326,107 @@ end end @testitem "ContinuousMultivariateLogPdf: vague" begin - d = vague(ContinuousMultivariateLogPdf, 2) + for k in 2:5 + d = vague(ContinuousMultivariateLogPdf, k) - @test typeof(d) <: ContinuousMultivariateLogPdf - @test d ≈ ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> 2.0) + @test typeof(d) <: ContinuousMultivariateLogPdf + @test d(rand(k)) ≈ 0 + end end @testitem "ContinuousMultivariateLogPdf: prod" begin - dist = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> 2.0 * -x'x) - d2 = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> 3.0 * -x'x) + import DomainSets: FullSpace, HalfLine - product = prod(ProdAnalytical(), dist, d2) + dist = ContinuousMultivariateLogPdf(FullSpace()^2, (x) -> 2.0 * -x'x) + d2 = ContinuousMultivariateLogPdf(FullSpace()^2, (x) -> 3.0 * -x'x) + + pr1 = prod(GenericProd(), dist, d2) pt1 = ContinuousMultivariateLogPdf( - DomainSets.FullSpace()^2, (x) -> logpdf(dist, x) + logpdf(d2, x) + FullSpace()^2, (x) -> logpdf(dist, x) + logpdf(d2, x) ) - @test getdomain(product) === getdomain(dist) - @test getdomain(product) === getdomain(d2) - @test variate_form(typeof(product)) === variate_form(typeof(dist)) - @test variate_form(typeof(product)) === variate_form(typeof(d2)) - @test value_support(typeof(product)) === value_support(typeof(dist)) - @test value_support(typeof(product)) === value_support(typeof(d2)) - @test support(product) === support(dist) - @test support(product) === support(d2) - @test isapprox(product, pt1, atol=1e-12) + @test variate_form(typeof(pr1)) === variate_form(typeof(dist)) + @test variate_form(typeof(pr1)) === variate_form(typeof(d2)) + @test value_support(typeof(pr1)) === value_support(typeof(dist)) + @test value_support(typeof(pr1)) === value_support(typeof(d2)) + @test support(pr1) === support(dist) + @test support(pr1) === support(d2) + + for x in [randn(2) for _ in 1:10] + @test isapprox(logpdf(pr1, x), logpdf(pt1, x)) + @test isapprox(pdf(pr1, x), pdf(pt1, x)) + end - result = ContinuousMultivariateLogPdf(DomainSets.HalfLine()^2, (x) -> 2.0 * -x'x) - d4 = ContinuousMultivariateLogPdf(DomainSets.HalfLine()^2, (x) -> 3.0 * -x'x) + result = ContinuousMultivariateLogPdf(HalfLine()^2, (x) -> 2.0 * -x'x) + d4 = ContinuousMultivariateLogPdf(HalfLine()^2, (x) -> 3.0 * -x'x) - pr2 = prod(ProdAnalytical(), result, d4) + pr2 = prod(GenericProd(), result, d4) pt2 = ContinuousMultivariateLogPdf( - DomainSets.HalfLine()^2, (x) -> logpdf(result, x) + logpdf(d4, x) + HalfLine()^2, (x) -> logpdf(result, x) + logpdf(d4, x) ) - @test getdomain(pr2) === getdomain(result) - @test getdomain(pr2) === getdomain(d4) @test variate_form(typeof(pr2)) === variate_form(typeof(result)) @test variate_form(typeof(pr2)) === variate_form(typeof(d4)) @test value_support(typeof(pr2)) === value_support(typeof(result)) @test value_support(typeof(pr2)) === value_support(typeof(d4)) @test support(pr2) === support(result) @test support(pr2) === support(d4) - @test isapprox(pr2, pt2, atol=1e-12) - @test !isapprox(product, pr2; atol=1e-12) + for x in [rand(2) for _ in 1:10] + @test isapprox(logpdf(pr2, x), logpdf(pt2, x)) + @test isapprox(pdf(pr2, x), pdf(pt2, x)) + end - d5 = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> 2.0 * -x'x) - d6 = ContinuousMultivariateLogPdf(DomainSets.HalfLine()^2, (x) -> 2.0 * -x'x) - @test_throws AssertionError prod(ProdAnalytical(), d5, d6) + d5 = ContinuousMultivariateLogPdf(FullSpace()^2, (x) -> 2.0 * -x'x) + d6 = ContinuousMultivariateLogPdf(HalfLine()^2, (x) -> 2.0 * -x'x) + @test_throws AssertionError logpdf(prod(GenericProd(), d5, d6), [1.0, -1.0]) # domains are incompatible end @testitem "ContinuousMultivariateLogPdf: vectorised-prod" begin + import DomainSets: FullSpace, HalfLine + f = (x) -> 2.0 * -x'x - dist = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, f) - d2 = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, f) - result = ContinuousMultivariateLogPdf(DomainSets.FullSpace()^2, (x) -> f(x) + f(x)) + dist = ContinuousMultivariateLogPdf(FullSpace()^2, f) + result = ContinuousMultivariateLogPdf(FullSpace()^2, (x) -> 3 * f(x)) - product = prod(ProdAnalytical(), dist, d2) + product = prod(GenericProd(), prod(GenericProd(), dist, dist), dist) - @test product isa GenericLogPdfVectorisedProduct - @test getdomain(product) === getdomain(dist) - @test getdomain(product) === getdomain(d2) + @test product isa LinearizedProductOf @test variate_form(typeof(product)) === variate_form(typeof(dist)) - @test variate_form(typeof(product)) === variate_form(typeof(d2)) + @test variate_form(typeof(product)) === variate_form(typeof(result)) @test value_support(typeof(product)) === value_support(typeof(dist)) - @test value_support(typeof(product)) === value_support(typeof(d2)) + @test value_support(typeof(product)) === value_support(typeof(result)) @test support(product) === support(dist) - @test support(product) === support(d2) + @test support(product) === support(result) + + for x in [rand(2) for _ in 1:10] + @test pdf(product, x) ≈ pdf(result, x) + @test logpdf(product, x) ≈ logpdf(result, x) + end - for point in [rand(Float64, 2) for _ in 1:10] - @test pdf(product, point) ≈ pdf(result, point) - @test logpdf(product, point) ≈ logpdf(result, point) + # Test internal side-effects + another_product = prod(GenericProd(), product, dist) + + for x in [rand(2) for _ in 1:10] + @test logpdf(product, x) ≈ logpdf(result, x) + @test pdf(product, x) ≈ pdf(result, x) + + @test logpdf(another_product, x) ≈ (logpdf(product, x) + logpdf(dist, x)) + @test pdf(another_product, x) ≈ (pdf(product, x) * pdf(dist, x)) end end @testitem "ContinuousMultivariateLogPdf: convert" begin - d = DomainSets.FullSpace()^2 - l = (x) -> 1.0 + import DomainSets: FullSpace, HalfLine + + for k in 2:5 + d = FullSpace()^k + l = (x) -> 0 - c = convert(ContinuousMultivariateLogPdf, d, l) - @test typeof(c) <: ContinuousMultivariateLogPdf - @test isapprox(c, ContinuousMultivariateLogPdf(d, l), atol=1e-12) + c = convert(ContinuousMultivariateLogPdf, d, l) + @test typeof(c) <: ContinuousMultivariateLogPdf - c2 = convert(ContinuousMultivariateLogPdf, c) - @test typeof(c2) <: ContinuousMultivariateLogPdf - @test isapprox(c2, ContinuousMultivariateLogPdf(d, l), atol=1e-12) + c2 = convert(ContinuousMultivariateLogPdf, c) + @test typeof(c2) <: ContinuousMultivariateLogPdf + end end \ No newline at end of file