diff --git a/test/distributions_setuptests.jl b/test/distributions_setuptests.jl index 256c2472..2b2fd047 100644 --- a/test/distributions_setuptests.jl +++ b/test/distributions_setuptests.jl @@ -11,7 +11,8 @@ import ExponentialFamily: promote_samplefloattype, paramfloattype, convert_paramfloattype, - FactorizedJoint + FactorizedJoint, + PromoteTypeConverter function generate_random_distributions(::Type{V} = Any; seed = abs(rand(Int)), Types = (Float32, Float64)) where {V} rng = StableRNG(seed) diff --git a/test/distributions_tests.jl b/test/distributions_tests.jl index c8c4c04c..0820c40c 100644 --- a/test/distributions_tests.jl +++ b/test/distributions_tests.jl @@ -154,3 +154,13 @@ end ) end end + +@testitem "TypeConverter" begin + include("./distributions_setuptests.jl") + + for original_T in (Float16, Float32, Float64), target_T in (Float16, Float32, Float64), n in (1, 2, 3) + converter = PromoteTypeConverter(target_T, convert) + + @test typeof(@inferred(converter(rand(original_T)))) === target_T + end +end \ No newline at end of file